// dada_lang/main_lib/test/panic_hook.rs

use annotate_snippets::{Level, Renderer, Snippet};
use std::panic::{AssertUnwindSafe, PanicHookInfo};

4/// The test runner overrides the panic hook temporarily.
5/// When a panic occurs, info about the panic is recorded here.
6#[derive(Debug)]
7pub(super) struct CapturedPanic {
8    pub file: String,
9    pub line: u32,
10    pub column: u32,
11    pub message: String,
12}
13
14thread_local! {
15    static LAST_PANIC: std::cell::Cell<Option<CapturedPanic>> = const { std::cell::Cell::new(None) };
16}
17
18pub(super) fn recording_panics<R>(op: impl FnOnce() -> R) -> R {
19    let _guard = ReplacePanicHook::new();
20    std::panic::set_hook(Box::new(|panic_hook_info| {
21        let mut panic_info = CapturedPanic {
22            file: "(unknown location)".to_string(),
23            line: 0,
24            column: 0,
25            message: "(unknown panic message)".to_string(),
26        };
27
28        if let Some(message) = panic_hook_info.payload_as_str() {
29            panic_info.message = message.to_string();
30        }
31
32        if let Some(location) = panic_hook_info.location() {
33            panic_info.file = location.file().to_string();
34            panic_info.line = location.line();
35            panic_info.column = location.column();
36        }
37
38        LAST_PANIC.with(|cell| {
39            cell.set(Some(panic_info));
40        })
41    }));
42    op()
43}
44
45pub(super) fn captured_panic() -> Option<CapturedPanic> {
46    LAST_PANIC.with(|cell| cell.take())
47}
48
/// Boxed panic hook, as accepted by `std::panic::set_hook` and returned by
/// `std::panic::take_hook`.
type PanicHook = Box<dyn Fn(&PanicHookInfo<'_>) + Sync + Send + 'static>;

/// Scope guard that unregisters the current panic hook on creation and
/// reinstalls it when dropped.
struct ReplacePanicHook {
    /// The hook that was active when the guard was created. It is `Some`
    /// until `drop` moves it back into place.
    old_hook: Option<PanicHook>,
}

impl ReplacePanicHook {
    /// Takes the currently registered panic hook, which leaves the default
    /// hook active, and stashes it for later restoration.
    fn new() -> Self {
        let previous: PanicHook = std::panic::take_hook();
        ReplacePanicHook {
            old_hook: Some(previous),
        }
    }
}

impl Drop for ReplacePanicHook {
    /// Reinstalls the hook captured by `new`.
    ///
    /// NOTE: `std::panic::set_hook` panics when called from a panicking
    /// thread, so this guard must not be dropped during an unwind.
    fn drop(&mut self) {
        let previous = self.old_hook.take().unwrap();
        std::panic::set_hook(previous);
    }
}

68impl CapturedPanic {
69    pub(super) fn render(&self) -> String {
70        let Ok(source_contents) = std::fs::read_to_string(&self.file) else {
71            return format!(
72                "{}:{}:{}: {}",
73                self.file, self.line, self.column, self.message
74            );
75        };
76
77        let line_starts = std::iter::once(0)
78            .chain(
79                source_contents
80                    .char_indices()
81                    .filter_map(|(i, c)| (c == '\n').then_some(i + 1)),
82            )
83            .chain(std::iter::once(source_contents.len()))
84            .collect::<Vec<_>>();
85
86        let error_offset = line_starts[self.line as usize - 1] + self.column as usize - 1;
87        let error_range = error_offset..error_offset + 1;
88
89        Renderer::plain()
90            .render(
91                Level::Error.title(&self.message).snippet(
92                    Snippet::source(&source_contents)
93                        .line_start(1)
94                        .origin(&self.file)
95                        .fold(true)
96                        .annotation(Level::Error.span(error_range).label(&self.message)),
97                ),
98            )
99            .to_string()
100    }
101}