1#![allow(clippy::arc_with_non_send_sync)] use std::{
4 fmt::Debug,
5 future::Future,
6 panic::Location,
7 rc::Rc,
8 sync::{
9 Arc, Mutex, RwLock,
10 atomic::{AtomicBool, AtomicU64, Ordering},
11 },
12 task::{Context, Poll, Waker},
13};
14
15use crate::ir::indices::InferVarIndex;
16use check_task::CheckTask;
17use dada_ir_ast::{
18 diagnostic::{Diagnostic, Err, Errors, Level},
19 span::Span,
20};
21use dada_util::{Map, Set, vecext::VecExt};
22use serde::Serialize;
23
24use crate::{check::env::Env, check::inference::InferenceVarData};
25
26use super::{
27 debug::{LogHandle, RootTaskDescription, TaskDescription, event_argument},
28 inference::InferenceVarDataChanged,
29};
30
/// Cheaply cloneable handle to the state shared by one `Runtime::execute` call.
///
/// All clones point at the same [`RuntimeData`] through an `Rc`, so a
/// `Runtime` is `!Send` and stays on the thread that created it.
#[derive(Clone)]
pub(crate) struct Runtime<'db> {
    /// Shared state; reachable from a `Runtime` directly through `Deref`.
    data: Rc<RuntimeData<'db>>,
}
35
/// State shared by every task spawned during one `Runtime::execute` call.
///
/// NOTE(review): `Runtime` wraps this in an `Rc` (so it is single-threaded);
/// the `Mutex`/`RwLock`/atomic fields appear to be used purely for interior
/// mutability — confirm before relying on any cross-thread behaviour.
pub(crate) struct RuntimeData<'db> {
    /// The compiler database.
    pub db: &'db dyn crate::Db,

    /// Data for every inference variable, indexed by `InferVarIndex::as_usize()`.
    inference_vars: RwLock<Vec<InferenceVarData<'db>>>,

    /// `(lower, upper)` pairs recorded by `insert_sub_infer_var_pair`.
    sub_inference_var_pairs: Mutex<Set<(InferVarIndex, InferVarIndex)>>,

    /// Tasks ready to be polled; `drain` pops from the end of this vector.
    ready_to_execute: Mutex<Vec<Arc<CheckTask>>>,

    /// For each inference variable, the wakers of tasks blocked on it. Entries
    /// are removed (and woken) when the variable's data changes, and the whole
    /// map is emptied when the runtime is marked complete.
    waiting_on_inference_var: Mutex<Map<InferVarIndex, Vec<EqWaker>>>,

    /// Set by `mark_complete`; afterwards inference variables may no longer be
    /// created or mutated.
    complete: AtomicBool,

    /// Source of the ids handed to newly created `CheckTask`s.
    next_task_id: AtomicU64,

    /// Root log handle for this run.
    root_log: LogHandle<'db>,
}
66
/// A [`Waker`] wrapper compared by identity: two `EqWaker`s are equal when the
/// wrapped wakers have the same data pointer and the same vtable pointer.
///
/// This lets `block_on_inference_var` avoid registering the same task twice
/// for one inference variable.
struct EqWaker {
    waker: Waker,
}

impl EqWaker {
    /// Wraps a clone of `waker`.
    fn new(waker: &Waker) -> Self {
        let waker = Waker::clone(waker);
        EqWaker { waker }
    }
}

impl PartialEq for EqWaker {
    /// Pointer-identity comparison of the data pointer and vtable pointer.
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.waker, &other.waker);
        let same_data = std::ptr::addr_eq(mine.data(), theirs.data());
        same_data && std::ptr::addr_eq(mine.vtable(), theirs.vtable())
    }
}

impl Eq for EqWaker {}
90
91impl<'db> std::ops::Deref for Runtime<'db> {
92 type Target = RuntimeData<'db>;
93
94 fn deref(&self) -> &Self::Target {
95 &self.data
96 }
97}
98
99impl<'db> Runtime<'db> {
100 #[track_caller]
101 pub(crate) fn execute<T, R>(
102 db: &'db dyn crate::Db,
103 span: Span<'db>,
104 message: &'static str,
105 values: &[&dyn erased_serde::Serialize],
106 constrain: impl AsyncFnOnce(&Runtime<'db>) -> T + 'db,
107 cleanup: impl FnOnce(T) -> R + 'db,
108 ) -> R
109 where
110 T: 'db,
111 R: 'db + Err<'db> + erased_serde::Serialize + Debug,
112 {
113 let compiler_location = Location::caller();
114 let runtime = Runtime::new(db, compiler_location, span, message, values);
115 let (channel_tx, channel_rx) = std::sync::mpsc::channel();
116 runtime.spawn_future({
117 let runtime = runtime.clone();
118 async move {
119 let result = constrain(&runtime).await;
120 channel_tx.send(result).unwrap();
121 }
122 });
123
124 runtime.drain();
126
127 runtime.mark_complete();
129 runtime.drain();
130
131 let result = match channel_rx.try_recv() {
132 Ok(v) => cleanup(v),
133
134 Err(_) => R::err(db, runtime.report_type_annotations_needed(span)),
136 };
137
138 runtime
139 .root_log
140 .log(compiler_location, "final result", &[&result]);
141
142 runtime.root_log.dump(span);
143
144 result
145 }
146
147 fn new(
148 db: &'db dyn crate::Db,
149 compiler_location: &'static Location<'static>,
150 span: Span<'db>,
151 message: &'static str,
152 values: &[&dyn erased_serde::Serialize],
153 ) -> Self {
154 Self {
155 data: Rc::new(RuntimeData {
156 db,
157 complete: Default::default(),
158 inference_vars: Default::default(),
159 sub_inference_var_pairs: Default::default(),
160 ready_to_execute: Default::default(),
161 waiting_on_inference_var: Default::default(),
162 next_task_id: Default::default(),
163 root_log: LogHandle::root(
164 db,
165 compiler_location,
166 RootTaskDescription {
167 span,
168 message: Some(message),
169 values: Some(event_argument(values)),
170 },
171 ),
172 }),
173 }
174 }
175
176 pub fn root_log(&self) -> LogHandle<'db> {
178 self.root_log.duplicate_root_handle()
179 }
180
181 fn next_task_id(&self) -> u64 {
182 self.data.next_task_id.fetch_add(1, Ordering::Relaxed)
183 }
184
185 #[track_caller]
187 fn spawn_future(&self, future: impl Future<Output = ()> + 'db) {
188 let task = CheckTask::new(Location::caller(), self, future);
189 self.ready_to_execute.lock().unwrap().push(task);
190 }
191
192 fn pop_task(&self) -> Option<Arc<CheckTask>> {
194 self.ready_to_execute.lock().unwrap().pop()
195 }
196
197 fn drain(&self) {
199 while let Some(ready) = self.pop_task() {
200 ready.execute(self);
201 }
202 }
203
204 fn mark_complete(&self) {
206 self.complete.store(true, Ordering::Relaxed);
207
208 {
209 let mut inference_vars = self.inference_vars.write().unwrap();
210 for data in inference_vars.iter_mut() {
211 data.fallback(self.db);
212 }
213 }
214
215 let map = std::mem::take(&mut *self.waiting_on_inference_var.lock().unwrap());
216 for EqWaker { waker } in map.into_values().flatten() {
217 waker.wake();
218 }
219 }
220
221 pub fn check_complete(&self) -> bool {
224 self.complete.load(Ordering::Relaxed)
225 }
226
227 pub fn fresh_inference_var(
231 &self,
232 log: &LogHandle,
233 data: InferenceVarData<'db>,
234 ) -> InferVarIndex {
235 assert!(!self.check_complete());
236 let mut inference_vars = self.inference_vars.write().unwrap();
237 let infer = InferVarIndex::from(inference_vars.len());
238 log.infer(Location::caller(), "fresh_inference_var", infer, &[&data]);
239 inference_vars.push(data);
240 infer
241 }
242
243 pub fn loop_on_inference_var<T>(
246 &self,
247 infer: InferVarIndex,
248 compiler_location: &'static Location<'static>,
249 log: &LogHandle<'db>,
250 mut op: impl FnMut(&InferenceVarData<'db>) -> Option<T>,
251 ) -> impl Future<Output = Option<T>>
252 where
253 T: Serialize,
254 {
255 std::future::poll_fn(move |cx| {
256 log.infer(compiler_location, "loop_on_inference_var", infer, &[]);
257 let data = self.with_inference_var_data(infer, |data| op(data));
258 match data {
259 Some(v) => {
260 log.infer(
261 compiler_location,
262 "loop_on_inference_var:success",
263 infer,
264 &[&v],
265 );
266 Poll::Ready(Some(v))
267 }
268 None => {
269 if self.check_complete() {
270 log.infer(compiler_location, "loop_on_inference_var:fail", infer, &[]);
271 Poll::Ready(None)
272 } else {
273 log.infer(compiler_location, "loop_on_inference_var:block", infer, &[]);
274 self.block_on_inference_var(compiler_location, log, infer, cx);
275 Poll::Pending
276 }
277 }
278 }
279 })
280 }
281
282 pub fn perm_infer(&self, infer: InferVarIndex) -> InferVarIndex {
285 self.with_inference_var_data(infer, |data| data.perm())
286 .unwrap_or(infer)
287 }
288
289 pub fn with_inference_var_data<T>(
294 &self,
295 infer: InferVarIndex,
296 op: impl FnOnce(&InferenceVarData<'db>) -> T,
297 ) -> T {
298 let inference_vars = self.inference_vars.read().unwrap();
299 op(&inference_vars[infer.as_usize()])
300 }
301
302 #[track_caller]
312 pub fn mutate_inference_var_data<T>(
313 &self,
314 infer: InferVarIndex,
315 log: &LogHandle,
316 op: impl FnOnce(&mut InferenceVarData<'db>) -> T,
317 ) -> T
318 where
319 T: InferenceVarDataChanged,
320 {
321 assert!(!self.check_complete());
322 let mut inference_vars = self.inference_vars.write().unwrap();
323 let inference_var = &mut inference_vars[infer.as_usize()];
324 let result = op(inference_var);
325 if result.did_change() {
326 log.infer(
327 Location::caller(),
328 "mutate_inference_var_data",
329 infer,
330 &[&*inference_var],
331 );
332 self.wake_tasks_monitoring_inference_var(infer);
333 }
334 result
335 }
336
337 #[track_caller]
341 pub fn insert_sub_infer_var_pair(
342 &self,
343 lower: InferVarIndex,
344 upper: InferVarIndex,
345 log: &LogHandle,
346 ) -> bool {
347 log.log(
348 Location::caller(),
349 "insert_sub_infer_var_pair",
350 &[&lower, &upper],
351 );
352 self.sub_inference_var_pairs
353 .lock()
354 .unwrap()
355 .insert((lower, upper))
356 }
357
358 fn wake_tasks_monitoring_inference_var(&self, infer: InferVarIndex) {
359 let mut waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
360 let wakers = waiting_on_inference_var.remove(&infer);
361 for EqWaker { waker } in wakers.into_iter().flatten() {
362 waker.wake();
363 }
364 }
365
366 #[track_caller]
369 pub fn spawn<R>(
370 &self,
371 env: &Env<'db>,
372 task_description: TaskDescription<'db>,
373 check: impl 'db + AsyncFnOnce(&mut Env<'db>) -> R,
374 ) where
375 R: DeferResult,
376 {
377 let compiler_location = Location::caller();
378 let mut env = env.fork(|log| log.spawn(compiler_location, task_description));
379 self.spawn_future(async move { check(&mut env).await.finish() });
380 }
381
382 fn block_on_inference_var(
388 &self,
389 compiler_location: &'static Location<'static>,
390 log: &LogHandle<'db>,
391 infer: InferVarIndex,
392 cx: &mut Context<'_>,
393 ) {
394 assert!(!self.check_complete());
395 log.infer(
396 compiler_location,
397 "block_on_inference_var",
398 infer,
399 &[&infer],
400 );
401
402 let mut waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
403 waiting_on_inference_var
404 .entry(infer)
405 .or_default()
406 .push_if_not_contained(EqWaker::new(cx.waker()));
407 }
408
409 fn report_type_annotations_needed(&self, span: Span<'db>) -> dada_ir_ast::diagnostic::Reported {
410 let db = self.db;
411 let mut diag = Diagnostic::error(db, span, "type annotations needed").label(
412 db,
413 Level::Error,
414 span,
415 "I need to know some of the types in this function",
416 );
417 let waiting_on_inference_var = self.waiting_on_inference_var.lock().unwrap();
418 let inference_vars = self.inference_vars.read().unwrap();
419 for (var, _) in waiting_on_inference_var.iter() {
420 let var_data = &inference_vars[var.as_usize()];
421 let var_span = var_data.span();
422 diag = diag.label(
423 db,
424 Level::Note,
425 var_span,
426 "need to know the type here".to_string(),
427 );
428 }
429 diag.report(db)
430 }
431}
432
/// Task machinery for the single-threaded executor driven by `Runtime`.
///
/// Each spawned future is wrapped in an `Arc<CheckTask>`. Its hand-written
/// [`Waker`] (see `CHECK_TASK_VTABLE`) pushes the task back onto the runtime's
/// `ready_to_execute` queue.
///
/// NOTE(review): `Waker` is `Send + Sync`, but `CheckTask` holds an `Rc`
/// (via `Runtime`). These wakers must never be moved to or woken from another
/// thread. This is presumably also why the crate allows
/// `clippy::arc_with_non_send_sync`.
mod check_task {
    use dada_util::log::LogState;
    use futures::{FutureExt, future::LocalBoxFuture};
    use std::{
        future::Future,
        panic::Location,
        rc::Rc,
        sync::{Arc, Mutex},
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    };

    use super::Runtime;

    /// A spawned future plus the bookkeeping needed to poll and re-queue it.
    ///
    /// The `'db` lifetime of the runtime and of the future is erased to
    /// `'static` so the task can sit behind a type-erased [`Waker`]. Both
    /// places that convert the state back (`take_state`, `set_to_wait_state`)
    /// first assert that they were given the runtime this task was created
    /// with.
    pub(super) struct CheckTask {
        /// Where the task was spawned; passed to `dada_util::log::enter_task`.
        spawned_at: &'static Location<'static>,

        /// The owning runtime, with its lifetime erased from `'db` (see `new`).
        runtime: Runtime<'static>,

        /// Id from `Runtime::next_task_id`; passed to `dada_util::log::enter_task`.
        id: u64,

        /// Current lifecycle state, with its lifetime erased from `'db`.
        state: Mutex<CheckTaskState<'static>>,
    }

    /// Lifecycle of a [`CheckTask`].
    enum CheckTaskState<'chk> {
        /// The future has been taken out: it is being polled right now, or the
        /// task is still being constructed.
        Executing,

        /// The future is parked. The `LogState` captured when it was parked is
        /// handed back to `enter_task` on the next poll.
        Waiting(LocalBoxFuture<'chk, ()>, LogState),

        /// The future returned `Poll::Ready` and has been dropped.
        Complete,
    }

    impl CheckTask {
        /// Creates a task for `future`, already in the `Waiting` state.
        ///
        /// This does not enqueue the task; the caller must push it onto the
        /// ready queue.
        pub(super) fn new<'db>(
            spawned_at: &'static Location<'static>,
            runtime: &Runtime<'db>,
            future: impl Future<Output = ()> + 'db,
        ) -> Arc<Self> {
            let this = {
                let my_check = runtime.clone();

                // SAFETY (lifetime erasure): only the lifetime parameter changes.
                // The erased state is only turned back into `'db` data after
                // asserting that the caller passed this same runtime.
                // NOTE(review): soundness also relies on the task never being
                // polled after `'db` ends — confirm at call sites.
                let my_check =
                    unsafe { std::mem::transmute::<Runtime<'db>, Runtime<'static>>(my_check) };

                Arc::new(Self {
                    spawned_at,
                    runtime: my_check,
                    id: runtime.next_task_id(),
                    state: Mutex::new(CheckTaskState::Executing),
                })
            };

            this.set_to_wait_state(runtime, future.boxed_local());

            this
        }

        /// Swaps in `new_state` and returns the previous state.
        fn replace_state(&self, new_state: CheckTaskState<'static>) -> CheckTaskState<'static> {
            std::mem::replace(&mut *self.state.lock().unwrap(), new_state)
        }

        /// Takes the current state, leaving `Executing` in its place, and
        /// restores its `'db` lifetime.
        ///
        /// # Panics
        /// If `from_check` is not the runtime this task was created with.
        fn take_state<'db>(&self, from_check: &Runtime<'db>) -> CheckTaskState<'db> {
            // Only the runtime this task was created with may reattach `'db`.
            assert!(std::ptr::addr_eq(
                Rc::as_ptr(&self.runtime.data),
                Rc::as_ptr(&from_check.data),
            ));

            let state = self.replace_state(CheckTaskState::Executing);

            // SAFETY: undoes the `'db -> 'static` erasure performed in
            // `set_to_wait_state`, for the same runtime (asserted above).
            unsafe { std::mem::transmute::<CheckTaskState<'static>, CheckTaskState<'db>>(state) }
        }

        /// Parks `future` in the `Waiting` state together with the current
        /// `LogState`.
        ///
        /// # Panics
        /// If `from_check` is a different runtime, or if the task was not in
        /// the `Executing` state.
        fn set_to_wait_state<'db>(
            &self,
            from_check: &Runtime<'db>,
            future: LocalBoxFuture<'db, ()>,
        ) {
            assert!(std::ptr::addr_eq(
                Rc::as_ptr(&self.runtime.data),
                Rc::as_ptr(&from_check.data),
            ));

            // SAFETY (lifetime erasure): the future is only taken back out via
            // `take_state`, which reattaches `'db` for the same runtime.
            let future = unsafe {
                std::mem::transmute::<LocalBoxFuture<'db, ()>, LocalBoxFuture<'static, ()>>(future)
            };

            let old_state = self.replace_state(CheckTaskState::Waiting(future, LogState::get()));

            assert!(matches!(old_state, CheckTaskState::Executing));
        }

        /// Converts this `Arc` into a [`Waker`] that owns one strong count.
        fn waker(self: Arc<Self>) -> Waker {
            // SAFETY: the data pointer comes from `Arc::into_raw` on an
            // `Arc<CheckTask>`, which is exactly what every `CHECK_TASK_VTABLE`
            // entry expects to receive.
            unsafe {
                Waker::from_raw(RawWaker::new(
                    Arc::into_raw(self) as *const (),
                    &CHECK_TASK_VTABLE,
                ))
            }
        }

        /// Pushes this task onto its runtime's ready queue.
        fn wake(self: Arc<Self>) {
            // Clone the runtime so the lock guard borrows a local instead of
            // `self`, which is moved into the queue below.
            let check = self.runtime.clone();
            let mut ready_to_execute = check.ready_to_execute.lock().unwrap();
            ready_to_execute.push(self);
        }

        /// Polls the task once if it is waiting.
        ///
        /// Afterwards the task is either `Complete` or parked again in
        /// `Waiting`.
        pub(super) fn execute(self: Arc<Self>, from_check: &Runtime<'_>) {
            let state = self.take_state(from_check);
            match state {
                // A task can be woken after it has already finished (e.g. it
                // was queued more than once); just restore the state.
                CheckTaskState::Complete => {
                    *self.state.lock().unwrap() = CheckTaskState::Complete;
                }

                CheckTaskState::Waiting(mut future, log_state) => {
                    let _log = dada_util::log::enter_task(self.id, self.spawned_at, log_state);
                    match Future::poll(
                        future.as_mut(),
                        &mut Context::from_waker(&self.clone().waker()),
                    ) {
                        Poll::Ready(()) => {
                            *self.state.lock().unwrap() = CheckTaskState::Complete;
                        }
                        Poll::Pending => {
                            self.set_to_wait_state(from_check, future);
                        }
                    }
                }

                // `Executing` is only observable while a poll is in progress.
                // Reaching it here implies re-entrant execution, or a previous
                // poll panicked before the state was restored.
                CheckTaskState::Executing => {
                    unreachable!();
                }
            }
        }
    }

    /// Vtable for wakers whose data pointer is an `Arc<CheckTask>` produced by
    /// `Arc::into_raw`.
    ///
    /// NOTE(review): this is a `const`, so separate `&CHECK_TASK_VTABLE`
    /// expressions are not guaranteed to share one address. `EqWaker`
    /// compares vtable addresses. Consider a `static` if deduplication ever
    /// compares wakers built at different sites.
    const CHECK_TASK_VTABLE: RawWakerVTable = RawWakerVTable::new(
        // clone: add one strong count. Rebuild the Arc, clone it, and forget
        // the original so the borrowed reference stays valid.
        |p| {
            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
            let q = p.clone();
            std::mem::forget(p);
            RawWaker::new(Arc::into_raw(q) as *const (), &CHECK_TASK_VTABLE)
        },
        // wake: consumes the waker's strong count.
        |p| {
            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
            p.wake();
        },
        // wake_by_ref: wake through a fresh clone; the borrowed count is
        // forgotten, not dropped.
        |p| {
            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
            p.clone().wake();
            std::mem::forget(p);
        },
        // drop: release the waker's strong count.
        |p| {
            let p: Arc<CheckTask> = unsafe { Arc::from_raw(p as *const CheckTask) };
            std::mem::drop(p);
        },
    );
}
629
630pub(crate) trait DeferResult {
632 fn finish(self);
633}
634
635impl DeferResult for () {
636 fn finish(self) {}
637}
638
639impl<T: DeferResult> DeferResult for Errors<T> {
640 fn finish(self) {
641 match self {
642 Ok(v) => v.finish(),
643 Err(_reported) => (),
644 }
645 }
646}