1use std::{panic::Location, pin::pin};
2
3use dada_ir_ast::diagnostic::Errors;
4use futures::{
5 StreamExt,
6 future::{Either, LocalBoxFuture},
7 stream::FuturesUnordered,
8};
9use serde::Serialize;
10
11use crate::{
12 check::{
13 debug::TaskDescription,
14 inference::Direction,
15 red::RedTy,
16 report::{Because, OrElse},
17 },
18 ir::{indices::InferVarIndex, types::SymPerm},
19};
20
21use crate::check::{env::Env, inference::InferenceVarData, report::ArcOrElse};
22
23use super::infer_bounds::{RedPermBoundIterator, RedTyBoundIterator, SymGenericTermBoundIterator};
24
25impl<'db> Env<'db> {
26 pub async fn require(
27 &mut self,
28 a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
29 because: impl FnOnce(&mut Env<'db>) -> Because<'db>,
30 or_else: &dyn OrElse<'db>,
31 ) -> Errors<()> {
32 if a(self).await? {
33 Ok(())
34 } else {
35 let because = because(self);
36 Err(or_else.report(self, because))
37 }
38 }
39
40 #[track_caller]
41 pub fn require_for_all<T>(
42 &mut self,
43 items: impl IntoIterator<Item = T>,
44 f: impl AsyncFn(&mut Env<'db>, T) -> Errors<()>,
45 ) -> impl Future<Output = Errors<()>> {
46 let caller = Location::caller();
47 async move {
48 let this = &*self;
49 let f = &f;
50 let _v: Vec<()> =
51 futures::future::try_join_all(items.into_iter().zip(0..).map(|(elem, index)| {
52 let mut env =
53 this.fork(|handle| handle.spawn(caller, TaskDescription::Require(index)));
54 async move { f(&mut env, elem).await }
55 }))
56 .await?;
57 Ok(())
58 }
59 }
60
61 pub fn require_all(&mut self) -> RequireAll<'_, 'db> {
62 RequireAll {
63 env: self,
64 required: vec![],
65 }
66 }
67
68 #[track_caller]
69 pub fn require_both(
70 &mut self,
71 a: impl AsyncFnOnce(&mut Self) -> Errors<()>,
72 b: impl AsyncFnOnce(&mut Self) -> Errors<()>,
73 ) -> impl Future<Output = Errors<()>> {
74 let caller = Location::caller();
75 async move {
76 let ((), ()) = futures::future::try_join(
77 async {
78 let mut env =
79 self.fork(|handle| handle.spawn(caller, TaskDescription::Require(0)));
80 let result = a(&mut env).await;
81 env.log_result(caller, result)
82 },
83 async {
84 let mut env =
85 self.fork(|handle| handle.spawn(caller, TaskDescription::Require(1)));
86 let result = b(&mut env).await;
87 env.log_result(caller, result)
88 },
89 )
90 .await?;
91 Ok(())
92 }
93 }
94
95 #[track_caller]
96 pub fn join<A, B>(
97 &mut self,
98 a: impl AsyncFnOnce(&mut Self) -> A,
99 b: impl AsyncFnOnce(&mut Self) -> B,
100 ) -> impl Future<Output = (A, B)>
101 where
102 A: erased_serde::Serialize,
103 B: erased_serde::Serialize,
104 {
105 let caller = Location::caller();
106 futures::future::join(
107 async {
108 let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Join(0)));
109 let result = a(&mut env).await;
110 env.log_result(caller, result)
111 },
112 async {
113 let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Join(1)));
114 let result = b(&mut env).await;
115 env.log_result(caller, result)
116 },
117 )
118 }
119
120 #[track_caller]
121 pub fn either(
122 &mut self,
123 a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
124 b: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
125 ) -> impl Future<Output = Errors<bool>> {
126 let caller = Location::caller();
127
128 async move {
129 let a = pin!(async {
130 let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Any(0)));
131 let result = a(&mut env).await;
132 env.log_result(caller, result)
133 });
134
135 let b = pin!(async {
136 let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Any(1)));
137 let result = b(&mut env).await;
138 env.log_result(caller, result)
139 });
140
141 match futures::future::select(a, b).await {
142 Either::Left((Ok(true), _)) | Either::Right((Ok(true), _)) => Ok(true),
143 Either::Left((Err(reported), _)) | Either::Right((Err(reported), _)) => {
144 Err(reported)
145 }
146 Either::Left((Ok(false), f)) => f.await,
147 Either::Right((Ok(false), f)) => f.await,
148 }
149 }
150 }
151
152 #[track_caller]
156 pub fn for_all<T>(
157 &mut self,
158 items: impl IntoIterator<Item = T>,
159 test_fn: impl AsyncFn(&mut Env<'db>, T) -> Errors<bool>,
160 ) -> impl Future<Output = Errors<bool>> {
161 let compiler_location = Location::caller();
162
163 async move {
164 let this = &*self;
165 let test_fn = &test_fn;
166 let unordered = FuturesUnordered::new();
167 for (item, index) in items.into_iter().zip(0..) {
168 unordered.push(async move {
169 let mut env = this.fork(|handle| {
170 handle.spawn(compiler_location, TaskDescription::All(index))
171 });
172 let result = test_fn(&mut env, item).await;
173 env.log_result(compiler_location, result)
174 });
175 }
176 let mut unordered = pin!(unordered);
177 while let Some(r) = unordered.next().await {
178 match r {
179 Ok(true) => {}
180 Ok(false) => return Ok(false),
181 Err(reported) => return Err(reported),
182 }
183 }
184 Ok(true)
185 }
186 }
187
188 #[track_caller]
192 pub fn exists<T>(
193 &mut self,
194 items: impl IntoIterator<Item = T>,
195 test_fn: impl AsyncFn(&mut Env<'db>, T) -> Errors<bool>,
196 ) -> impl Future<Output = Errors<bool>> {
197 let compiler_location = Location::caller();
198
199 async move {
200 let this = &*self;
201 let test_fn = &test_fn;
202 let unordered = FuturesUnordered::new();
203 for (item, index) in items.into_iter().zip(0..) {
204 unordered.push(async move {
205 let mut env = this.fork(|handle| {
206 handle.spawn(compiler_location, TaskDescription::Any(index))
207 });
208 let result = test_fn(&mut env, item).await;
209 env.log_result(compiler_location, result)
210 });
211 }
212 let mut unordered = pin!(unordered);
213 while let Some(r) = unordered.next().await {
214 match r {
215 Ok(true) => return Ok(true),
216 Ok(false) => {}
217 Err(reported) => return Err(reported),
218 }
219 }
220 Ok(false)
221 }
222 }
223
224 #[track_caller]
226 pub fn both(
227 &mut self,
228 a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
229 b: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
230 ) -> impl Future<Output = Errors<bool>> {
231 let compiler_location = Location::caller();
232
233 async move {
234 let a = async {
235 let mut env =
236 self.fork(|handle| handle.spawn(compiler_location, TaskDescription::All(0)));
237 let result = a(&mut env).await;
238 env.log_result(compiler_location, result)
239 };
240
241 let b = async {
242 let mut env =
243 self.fork(|handle| handle.spawn(compiler_location, TaskDescription::All(1)));
244 let result = b(&mut env).await;
245 env.log_result(compiler_location, result)
246 };
247
248 match futures::future::select(pin!(a), pin!(b)).await {
249 Either::Left((Ok(false), _)) | Either::Right((Ok(false), _)) => Ok(false),
250 Either::Left((Err(reported), _)) | Either::Right((Err(reported), _)) => {
251 Err(reported)
252 }
253 Either::Left((Ok(true), f)) => f.await,
254 Either::Right((Ok(true), f)) => f.await,
255 }
256 }
257 }
258
259 pub fn term_bounds(
267 &self,
268 perm: SymPerm<'db>,
269 infer: InferVarIndex,
270 ) -> SymGenericTermBoundIterator<'db> {
271 SymGenericTermBoundIterator::new(self, perm, infer)
272 }
273
274 pub fn red_perm_bounds(&self, infer: InferVarIndex) -> RedPermBoundIterator<'db> {
276 RedPermBoundIterator::new(self, infer)
277 }
278
279 #[expect(dead_code)]
284 pub fn red_ty_bounds(&self, infer: InferVarIndex) -> RedTyBoundIterator<'db> {
285 RedTyBoundIterator::new(self, infer)
286 }
287
288 #[track_caller]
289 pub fn loop_on_inference_var<T>(
290 &self,
291 infer: InferVarIndex,
292 op: impl FnMut(&InferenceVarData<'db>) -> Option<T>,
293 ) -> impl Future<Output = Option<T>>
294 where
295 T: Serialize,
296 {
297 let compiler_location = Location::caller();
298 self.runtime
299 .loop_on_inference_var(infer, compiler_location, &self.log, op)
300 }
301
302 pub async fn for_each_bound(
304 &mut self,
305 direction: Direction,
306 infer: InferVarIndex,
307 mut op: impl AsyncFnMut(&mut Env<'db>, &RedTy<'db>, ArcOrElse<'db>) -> Errors<()>,
308 ) -> Errors<()> {
309 let mut previous_red_ty = None;
310 loop {
311 let new_pair = self
312 .loop_on_inference_var(infer, |data| {
313 let (red_ty, or_else) = data.red_ty_bound(direction)?;
314 if let Some(previous_ty) = &previous_red_ty
315 && red_ty == *previous_ty
316 {
317 return None;
318 }
319 Some((red_ty, or_else))
320 })
321 .await;
322
323 match new_pair {
324 None => return Ok(()),
325 Some((lower_red_ty, or_else)) => {
326 self.log("for_each_bound", &[&infer, &lower_red_ty]);
327 previous_red_ty = Some(lower_red_ty);
328 op(self, previous_red_ty.as_ref().unwrap(), or_else).await?;
329 }
330 }
331 }
332 }
333}
334
335pub struct RequireAll<'env, 'db> {
336 env: &'env Env<'db>,
337 required: Vec<LocalBoxFuture<'env, Errors<()>>>,
338}
339
340impl<'env, 'db> RequireAll<'env, 'db> {
341 #[track_caller]
342 pub fn require(mut self, op: impl AsyncFnOnce(&mut Env<'db>) -> Errors<()> + 'env) -> Self {
343 let index = self.required.len();
344 let compiler_location = Location::caller();
345 let mut env = self
346 .env
347 .fork(|log| log.spawn(compiler_location, TaskDescription::All(index)));
348 let future = async move { op(&mut env).await };
349 self.required.push(Box::pin(future));
350 self
351 }
352
353 pub async fn finish(self) -> Errors<()> {
354 futures::future::try_join_all(self.required).await?;
355 Ok(())
356 }
357}