Skip to main content

dada_ir_sym/check/
runtime.rs

1#![allow(clippy::arc_with_non_send_sync)] // FIXME: we may want to do this later?
2
3use std::{
4    fmt::Debug,
5    future::Future,
6    panic::Location,
7    rc::Rc,
8    sync::{
9        Arc, Mutex, RwLock,
10        atomic::{AtomicBool, AtomicU64, Ordering},
11    },
12    task::{Context, Poll, Waker},
13};
14
15use crate::ir::indices::InferVarIndex;
16use check_task::CheckTask;
17use dada_ir_ast::{
18    diagnostic::{Diagnostic, Err, Errors, Level},
19    span::Span,
20};
21use dada_util::{Map, Set, vecext::VecExt};
22use serde::Serialize;
23
24use crate::{check::env::Env, check::inference::InferenceVarData};
25
26use super::{
27    debug::{LogHandle, RootTaskDescription, TaskDescription, event_argument},
28    inference::InferenceVarDataChanged,
29};
30
/// Shared handle to the state of one type-checking session.
///
/// Cloning is cheap: every clone points at the same [`RuntimeData`] through an
/// `Rc`, and the fields of that data are reachable directly through the
/// `Deref` impl for `Runtime`.
#[derive(Clone)]
pub(crate) struct Runtime<'db> {
    /// The session state shared by all clones of this handle.
    data: Rc<RuntimeData<'db>>,
}
35
/// State of one type-checking session, shared by every [`Runtime`] handle.
///
/// Interior mutability (`RwLock`/`Mutex`/atomics) lets the runtime and its
/// tasks update this state through shared references.
pub(crate) struct RuntimeData<'db> {
    /// Database the check runs against.
    pub db: &'db dyn crate::Db,

    /// Stores the data for each inference variable created thus far,
    /// indexed by [`InferVarIndex`] (the index is the position in the vector).
    inference_vars: RwLock<Vec<InferenceVarData<'db>>>,

    /// Pairs `(a, b)` of inference variables where `a <: b` is required.
    /// We insert into this set when we are relating two inference variables.
    /// If it is a new relation, then we know we must propagate bounds.
    sub_inference_var_pairs: Mutex<Set<(InferVarIndex, InferVarIndex)>>,

    /// List of tasks that are ready to execute.
    /// Tasks are pushed onto and popped from the end of the vector, so the
    /// most recently readied task runs first.
    ready_to_execute: Mutex<Vec<Arc<CheckTask>>>,

    /// List of tasks that are blocked, keyed by the variable they are blocked on.
    /// When the data for `InferVarIndex` changes, the tasks will be awoken.
    /// Wakers are deduplicated by pointer identity (see `EqWaker`) when registered.
    waiting_on_inference_var: Mutex<Map<InferVarIndex, Vec<EqWaker>>>,

    /// If true, inference state is frozen and will not change further.
    /// Set once by `Runtime::mark_complete`.
    complete: AtomicBool,

    /// Integer indicating the next task id; each task gets a unique id.
    next_task_id: AtomicU64,

    /// Root log handle for this check. This handle is not used to record
    /// events, only to export the overall log. During the check, environments
    /// carry a log handle that is specific to the current task.
    /// This way when we log an event it is tied to the task that caused it.
    root_log: LogHandle<'db>,
}
66
/// A [`Waker`] whose equality is the identity of its data and vtable pointers
/// (see the `PartialEq` impl), rather than anything about the task it wakes.
///
/// This relies on every waker for a given check task carrying the same data
/// and vtable pointers, which lets us recognize a waker that was already
/// registered instead of storing it twice.
struct EqWaker {
    waker: Waker,
}

impl EqWaker {
    /// Wraps a clone of `waker`; the caller keeps its own copy.
    fn new(waker: &Waker) -> Self {
        let waker = Waker::clone(waker);
        EqWaker { waker }
    }
}
81
82impl std::cmp::PartialEq for EqWaker {
83    fn eq(&self, other: &Self) -> bool {
84        std::ptr::addr_eq(self.waker.data(), other.waker.data())
85            && std::ptr::addr_eq(self.waker.vtable(), other.waker.vtable())
86    }
87}
88
89impl std::cmp::Eq for EqWaker {}
90
91impl<'db> std::ops::Deref for Runtime<'db> {
92    type Target = RuntimeData<'db>;
93
94    fn deref(&self) -> &Self::Target {
95        &self.data
96    }
97}
98
99impl<'db> Runtime<'db> {
100    #[track_caller]
101    pub(crate) fn execute<T, R>(
102        db: &'db dyn crate::Db,
103        span: Span<'db>,
104        message: &'static str,
105        values: &[&dyn erased_serde::Serialize],
106        constrain: impl AsyncFnOnce(&Runtime<'db>) -> T + 'db,
107        cleanup: impl FnOnce(T) -> R + 'db,
108    ) -> R
109    where
110        T: 'db,
111        R: 'db + Err<'db> + erased_serde::Serialize + Debug,
112    {
113        let compiler_location = Location::caller();
114        let runtime = Runtime::new(db, compiler_location, span, message, values);
115        let (channel_tx, channel_rx) = std::sync::mpsc::channel();
116        runtime.spawn_future({
117            let runtime = runtime.clone();
118            async move {
119                let result = constrain(&runtime).await;
120                channel_tx.send(result).unwrap();
121            }
122        });
123
124        // Run all spawned tasks until no more progress can be made.
125        runtime.drain();
126
127        // Mark inference as done and drain again. This may generate fresh errors.
128        runtime.mark_complete();
129        runtime.drain();
130
131        let result = match channel_rx.try_recv() {
132            Ok(v) => cleanup(v),
133
134            // FIXME: Obviously we need a better error message than this!
135            Err(_) => R::err(db, runtime.report_type_annotations_needed(span)),
136        };
137
138        runtime
139            .root_log
140            .log(compiler_location, "final result", &[&result]);
141
142        runtime.root_log.dump(span);
143
144        result
145    }
146
147    fn new(
148        db: &'db dyn crate::Db,
149        compiler_location: &'static Location<'static>,
150        span: Span<'db>,
151        message: &'static str,
152        values: &[&dyn erased_serde::Serialize],
153    ) -> Self {
154        Self {
155            data: Rc::new(RuntimeData {
156                db,
157                complete: Default::default(),
158                inference_vars: Default::default(),
159                sub_inference_var_pairs: Default::default(),
160                ready_to_execute: Default::default(),
161                waiting_on_inference_var: Default::default(),
162                next_task_id: Default::default(),
163                root_log: LogHandle::root(
164                    db,
165                    compiler_location,
166                    RootTaskDescription {
167                        span,
168                        message: Some(message),
169                        values: Some(event_argument(values)),
170                    },
171                ),
172            }),
173        }
174    }
175
176    /// Get a duplicate of the root log handle.
177    pub fn root_log(&self) -> LogHandle<'db> {
178        self.root_log.duplicate_root_handle()
179    }
180
181    fn next_task_id(&self) -> u64 {
182        self.data.next_task_id.fetch_add(1, Ordering::Relaxed)
183    }
184
185    /// Spawn a new check-task.
186    #[track_caller]
187    fn spawn_future(&self, future: impl Future<Output = ()> + 'db) {
188        let task = CheckTask::new(Location::caller(), self, future);
189        self.ready_to_execute.lock().unwrap().push(task);
190    }
191
192    /// Pop and return a task that is ready to execute (if any).
193    fn pop_task(&self) -> Option<Arc<CheckTask>> {
194        self.ready_to_execute.lock().unwrap().pop()
195    }
196
197    /// Continues running tasks until no more are left.
198    fn drain(&self) {
199        while let Some(ready) = self.pop_task() {
200            ready.execute(self);
201        }
202    }
203
204    /// Mark the inference process as complete and wake all tasks.
205    fn mark_complete(&self) {
206        self.complete.store(true, Ordering::Relaxed);
207
208        {
209            let mut inference_vars = self.inference_vars.write().unwrap();
210            for data in inference_vars.iter_mut() {
211                data.fallback(self.db);
212            }
213        }
214
215        let map = std::mem::take(&mut *self.waiting_on_inference_var.lock().unwrap());
216        for EqWaker { waker } in map.into_values().flatten() {
217            waker.wake();
218        }
219    }
220
221    /// Returns `true` if we have fully constructed the object IR for a given function.
222    /// Once this returns true, no more bounds will be added to inference variables.
223    pub fn check_complete(&self) -> bool {
224        self.complete.load(Ordering::Relaxed)
225    }
226
227    /// Creates a fresh inference variable of the given kind and universe.
228    ///
229    /// Low-level routine not to be directly invoked.
230    pub fn fresh_inference_var(
231        &self,
232        log: &LogHandle,
233        data: InferenceVarData<'db>,
234    ) -> InferVarIndex {
235        assert!(!self.check_complete());
236        let mut inference_vars = self.inference_vars.write().unwrap();
237        let infer = InferVarIndex::from(inference_vars.len());
238        log.infer(Location::caller(), "fresh_inference_var", infer, &[&data]);
239        inference_vars.push(data);
240        infer
241    }
242
243    /// Returns a future that blocks the current task until `op` returns `Some`.
244    /// `op` will be reinvoked each time the state of the inference variable may have changed.
245    pub fn loop_on_inference_var<T>(
246        &self,
247        infer: InferVarIndex,
248        compiler_location: &'static Location<'static>,
249        log: &LogHandle<'db>,
250        mut op: impl FnMut(&InferenceVarData<'db>) -> Option<T>,
251    ) -> impl Future<Output = Option<T>>
252    where
253        T: Serialize,
254    {
255        std::future::poll_fn(move |cx| {
256            log.infer(compiler_location, "loop_on_inference_var", infer, &[]);
257            let data = self.with_inference_var_data(infer, |data| op(data));
258            match data {
259                Some(v) => {
260                    log.infer(
261                        compiler_location,
262                        "loop_on_inference_var:success",
263                        infer,
264                        &[&v],
265                    );
266                    Poll::Ready(Some(v))
267                }
268                None => {
269                    if self.check_complete() {
270                        log.infer(compiler_location, "loop_on_inference_var:fail", infer, &[]);
271                        Poll::Ready(None)
272                    } else {
273                        log.infer(compiler_location, "loop_on_inference_var:block", infer, &[]);
274                        self.block_on_inference_var(compiler_location, log, infer, cx);
275                        Poll::Pending
276                    }
277                }
278            }
279        })
280    }
281
282    /// If `infer` is a type variable, returns the permission variable associated with `infer`.
283    /// If `infer` is a permission variable, just returns `infer`.
284    pub fn perm_infer(&self, infer: InferVarIndex) -> InferVarIndex {
285        self.with_inference_var_data(infer, |data| data.perm())
286            .unwrap_or(infer)
287    }
288
289    /// Read the current data for the given inference variable.
290    ///
291    /// A lock is held while the read occurs; deadlock will occur if there is an
292    /// attempt to mutate inference var data during the read.
293    pub fn with_inference_var_data<T>(
294        &self,
295        infer: InferVarIndex,
296        op: impl FnOnce(&InferenceVarData<'db>) -> T,
297    ) -> T {
298        let inference_vars = self.inference_vars.read().unwrap();
299        op(&inference_vars[infer.as_usize()])
300    }
301
302    /// Modify the data for the inference variable `infer`.
303    /// If the data actually changes (as indicated by the
304    /// return value of `op` via the [`InferenceVarDataChanged`][]
305    /// trait), then log the result and wake any tasks blocked on
306    /// this inference variable.
307    ///
308    /// `op` should invoke one of the mutation methods on [`InferenceVarData`][]
309    /// and nothing else. A write lock is held during the call so anything
310    /// more complex risks deadlock.
311    #[track_caller]
312    pub fn mutate_inference_var_data<T>(
313        &self,
314        infer: InferVarIndex,
315        log: &LogHandle,
316        op: impl FnOnce(&mut InferenceVarData<'db>) -> T,
317    ) -> T
318    where
319        T: InferenceVarDataChanged,
320    {
321        assert!(!self.check_complete());
322        let mut inference_vars = self.inference_vars.write().unwrap();
323        let inference_var = &mut inference_vars[infer.as_usize()];
324        let result = op(inference_var);
325        if result.did_change() {
326            log.infer(
327                Location::caller(),
328                "mutate_inference_var_data",
329                infer,
330                &[&*inference_var],
331            );
332            self.wake_tasks_monitoring_inference_var(infer);
333        }
334        result
335    }
336
337    /// Record that `lower <: upper` must hold,
338    /// returning true if this is the first time that this has been recorded
339    /// or false if it has been recorded before.
340    #[track_caller]
341    pub fn insert_sub_infer_var_pair(
342        &self,
343        lower: InferVarIndex,
344        upper: InferVarIndex,
345        log: &LogHandle,
346    ) -> bool {
347        log.log(
348            Location::caller(),
349            "insert_sub_infer_var_pair",
350            &[&lower, &upper],
351        );
352        self.sub_inference_var_pairs
353            .lock()
354            .unwrap()
355            .insert((lower, upper))
356    }
357
358    fn wake_tasks_monitoring_inference_var(&self, infer: InferVarIndex) {
359        let mut waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
360        let wakers = waiting_on_inference_var.remove(&infer);
361        for EqWaker { waker } in wakers.into_iter().flatten() {
362            waker.wake();
363        }
364    }
365
366    /// Execute the given future asynchronously from the main execution.
367    /// It must execute to completion eventually or an error will be reported.
368    #[track_caller]
369    pub fn spawn<R>(
370        &self,
371        env: &Env<'db>,
372        task_description: TaskDescription<'db>,
373        check: impl 'db + AsyncFnOnce(&mut Env<'db>) -> R,
374    ) where
375        R: DeferResult,
376    {
377        let compiler_location = Location::caller();
378        let mut env = env.fork(|log| log.spawn(compiler_location, task_description));
379        self.spawn_future(async move { check(&mut env).await.finish() });
380    }
381
382    /// Block the current task on changes to the given inference variable.
383    ///
384    /// # Panics
385    ///
386    /// If called when [`Self::check_complete`][] returns true.
387    fn block_on_inference_var(
388        &self,
389        compiler_location: &'static Location<'static>,
390        log: &LogHandle<'db>,
391        infer: InferVarIndex,
392        cx: &mut Context<'_>,
393    ) {
394        assert!(!self.check_complete());
395        log.infer(
396            compiler_location,
397            "block_on_inference_var",
398            infer,
399            &[&infer],
400        );
401
402        let mut waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
403        waiting_on_inference_var
404            .entry(infer)
405            .or_default()
406            .push_if_not_contained(EqWaker::new(cx.waker()));
407    }
408
409    fn report_type_annotations_needed(&self, span: Span<'db>) -> dada_ir_ast::diagnostic::Reported {
410        let db = self.db;
411        let mut diag = Diagnostic::error(db, span, "type annotations needed").label(
412            db,
413            Level::Error,
414            span,
415            "I need to know some of the types in this function",
416        );
417        let waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
418        let inference_vars = self.inference_vars.read().unwrap();
419        for (var, _) in waiting_on_inference_var.iter() {
420            let var_data = &inference_vars[var.as_usize()];
421            let var_span = var_data.span();
422            diag = diag.label(
423                db,
424                Level::Note,
425                var_span,
426                "need to know the type here".to_string(),
427            );
428        }
429        diag.report(db)
430    }
431}
432
433mod check_task {
434    use dada_util::log::LogState;
435    use futures::{FutureExt, future::LocalBoxFuture};
436    use std::{
437        future::Future,
438        panic::Location,
439        rc::Rc,
440        sync::{Arc, Mutex},
441        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
442    };
443
444    use super::Runtime;
445
446    /// # Safety notes
447    ///
448    /// This `Check` type is actually valid for some (existential) `'db`.
449    /// We erase this from the type system and simply use `'static` in the field types.
450    ///
451    /// As a result, we cannot safely access `state` unless we can be sure that `'db`
452    /// is still in scope.
453    ///
454    /// To do that, we keep a handle to `check` and then compare using `Arc::ptr_eq` to another `check` instance
455    /// which we have threaded through as an ordinary parameter (whose type must therefore be valid).
456    ///
457    /// If we are able to supply a `check` that has the same underlying `Arc`, and its type is valid,
458    /// then we know that `self.check` has that same type, and that therefore the
459    /// lifetimes in `self.state` are valid.
460    pub(super) struct CheckTask {
461        spawned_at: &'static Location<'static>,
462
463        /// Erased type: `Check<'db>`
464        runtime: Runtime<'static>,
465
466        /// Unique identifier for this task, used for debugging.
467        id: u64,
468
469        /// Erased type: `CheckTaskState<'chk>`
470        state: Mutex<CheckTaskState<'static>>,
471    }
472
473    enum CheckTaskState<'chk> {
474        Executing,
475        Waiting(LocalBoxFuture<'chk, ()>, LogState),
476        Complete,
477    }
478
479    impl CheckTask {
480        pub(super) fn new<'db>(
481            spawned_at: &'static Location<'static>,
482            runtime: &Runtime<'db>,
483            future: impl Future<Output = ()> + 'db,
484        ) -> Arc<Self> {
485            let this = {
486                let my_check = runtime.clone();
487
488                // UNSAFE: Erase lifetimes as described on [`CheckTask`][] above.
489                let my_check =
490                    unsafe { std::mem::transmute::<Runtime<'db>, Runtime<'static>>(my_check) };
491
492                Arc::new(Self {
493                    spawned_at,
494                    runtime: my_check,
495                    id: runtime.next_task_id(),
496                    state: Mutex::new(CheckTaskState::Executing),
497                })
498            };
499
500            this.set_to_wait_state(runtime, future.boxed_local());
501
502            this
503        }
504
505        fn replace_state(&self, new_state: CheckTaskState<'static>) -> CheckTaskState<'static> {
506            std::mem::replace(&mut *self.state.lock().unwrap(), new_state)
507        }
508
509        fn take_state<'db>(&self, from_check: &Runtime<'db>) -> CheckTaskState<'db> {
510            assert!(std::ptr::addr_eq(
511                Rc::as_ptr(&self.runtime.data),
512                Rc::as_ptr(&from_check.data),
513            ));
514
515            let state = self.replace_state(CheckTaskState::Executing);
516
517            // UNSAFE: Hide the lifetimes as described in the safety notes for [`CheckTask`][].
518            unsafe { std::mem::transmute::<CheckTaskState<'static>, CheckTaskState<'db>>(state) }
519        }
520
521        fn set_to_wait_state<'db>(
522            &self,
523            from_check: &Runtime<'db>,
524            future: LocalBoxFuture<'db, ()>,
525        ) {
526            assert!(std::ptr::addr_eq(
527                Rc::as_ptr(&self.runtime.data),
528                Rc::as_ptr(&from_check.data),
529            ));
530
531            // UNSAFE: Hide the lifetimes as described in the safety notes for [`CheckTask`][].
532            let future = unsafe {
533                std::mem::transmute::<LocalBoxFuture<'db, ()>, LocalBoxFuture<'static, ()>>(future)
534            };
535
536            let old_state = self.replace_state(CheckTaskState::Waiting(future, LogState::get()));
537
538            assert!(matches!(old_state, CheckTaskState::Executing));
539        }
540
541        fn waker(self: Arc<Self>) -> Waker {
542            // SAFETY: We uphold the RawWakerVtable contract.
543            // TODO: Document better.
544            unsafe {
545                Waker::from_raw(RawWaker::new(
546                    Arc::into_raw(self) as *const (),
547                    &CHECK_TASK_VTABLE,
548                ))
549            }
550        }
551
552        // Implement of the "Waker::wake" method.
553        // Invoked when an inference variable we were blocked on has changed or something like that.
554        // Adds this task to the list of ready-to-execute tasks.
555        // Note that we *may* already have completed: that's ok, then executing us will be a no-op.
556        fn wake(self: Arc<Self>) {
557            // Subtle: the lifetime annotations on `check` are declared as `'static`
558            // but they should be thought of as existential lifetimes.
559            //
560            // i.e., there is some 'chk and 'db that was associated with check
561            // when this task is created. We don't actually know (locally, anyway)
562            // that they are still valid -- `check` could have leaked via a ref-cycle.
563            //
564            // However, we do know that `check` is still
565            // ALLOCATED, because we hold a strong reference to it.
566            // We can add ourselves into the ready-to-execute list.
567            //
568            // The reader of this list will invoke `execute`, which will verify
569            // that the lifetimes are still valid.
570
571            let check = self.runtime.clone();
572            let mut ready_to_execute = check.ready_to_execute.lock().unwrap();
573            ready_to_execute.push(self);
574        }
575
576        pub(super) fn execute(self: Arc<Self>, from_check: &Runtime<'_>) {
577            let state = self.take_state(from_check);
578            match state {
579                CheckTaskState::Complete => {
580                    *self.state.lock().unwrap() = CheckTaskState::Complete;
581                }
582
583                CheckTaskState::Waiting(mut future, log_state) => {
584                    let _log = dada_util::log::enter_task(self.id, self.spawned_at, log_state);
585                    match Future::poll(
586                        future.as_mut(),
587                        &mut Context::from_waker(&self.clone().waker()),
588                    ) {
589                        Poll::Ready(()) => {
590                            *self.state.lock().unwrap() = CheckTaskState::Complete;
591                        }
592                        Poll::Pending => {
593                            self.set_to_wait_state(from_check, future);
594                        }
595                    }
596                }
597
598                CheckTaskState::Executing => {
599                    // Our execution loop is not re-entrant, so it shouldn't be possible
600                    // to hit the executing state while already executing.
601                    unreachable!();
602                }
603            }
604        }
605    }
606
607    const CHECK_TASK_VTABLE: RawWakerVTable = RawWakerVTable::new(
608        |p| {
609            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
610            let q = p.clone();
611            std::mem::forget(p);
612            RawWaker::new(Arc::into_raw(q) as *const (), &CHECK_TASK_VTABLE)
613        },
614        |p| {
615            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
616            p.wake();
617        },
618        |p| {
619            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
620            p.clone().wake();
621            std::mem::forget(p);
622        },
623        |p| {
624            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
625            std::mem::drop(p);
626        },
627    );
628}
629
630/// A trait to process the items that can result from a `Defer`.
631pub(crate) trait DeferResult {
632    fn finish(self);
633}
634
635impl DeferResult for () {
636    fn finish(self) {}
637}
638
639impl<T: DeferResult> DeferResult for Errors<T> {
640    fn finish(self) {
641        match self {
642            Ok(v) => v.finish(),
643            Err(_reported) => (),
644        }
645    }
646}