// dada_ir_sym/check/env/combinator.rs

1use std::{panic::Location, pin::pin};
2
3use dada_ir_ast::diagnostic::Errors;
4use futures::{
5    StreamExt,
6    future::{Either, LocalBoxFuture},
7    stream::FuturesUnordered,
8};
9use serde::Serialize;
10
11use crate::{
12    check::{
13        debug::TaskDescription,
14        inference::Direction,
15        red::RedTy,
16        report::{Because, OrElse},
17    },
18    ir::{indices::InferVarIndex, types::SymPerm},
19};
20
21use crate::check::{env::Env, inference::InferenceVarData, report::ArcOrElse};
22
23use super::infer_bounds::{RedPermBoundIterator, RedTyBoundIterator, SymGenericTermBoundIterator};
24
25impl<'db> Env<'db> {
26    pub async fn require(
27        &mut self,
28        a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
29        because: impl FnOnce(&mut Env<'db>) -> Because<'db>,
30        or_else: &dyn OrElse<'db>,
31    ) -> Errors<()> {
32        if a(self).await? {
33            Ok(())
34        } else {
35            let because = because(self);
36            Err(or_else.report(self, because))
37        }
38    }
39
40    #[track_caller]
41    pub fn require_for_all<T>(
42        &mut self,
43        items: impl IntoIterator<Item = T>,
44        f: impl AsyncFn(&mut Env<'db>, T) -> Errors<()>,
45    ) -> impl Future<Output = Errors<()>> {
46        let caller = Location::caller();
47        async move {
48            let this = &*self;
49            let f = &f;
50            let _v: Vec<()> =
51                futures::future::try_join_all(items.into_iter().zip(0..).map(|(elem, index)| {
52                    let mut env =
53                        this.fork(|handle| handle.spawn(caller, TaskDescription::Require(index)));
54                    async move { f(&mut env, elem).await }
55                }))
56                .await?;
57            Ok(())
58        }
59    }
60
61    pub fn require_all(&mut self) -> RequireAll<'_, 'db> {
62        RequireAll {
63            env: self,
64            required: vec![],
65        }
66    }
67
68    #[track_caller]
69    pub fn require_both(
70        &mut self,
71        a: impl AsyncFnOnce(&mut Self) -> Errors<()>,
72        b: impl AsyncFnOnce(&mut Self) -> Errors<()>,
73    ) -> impl Future<Output = Errors<()>> {
74        let caller = Location::caller();
75        async move {
76            let ((), ()) = futures::future::try_join(
77                async {
78                    let mut env =
79                        self.fork(|handle| handle.spawn(caller, TaskDescription::Require(0)));
80                    let result = a(&mut env).await;
81                    env.log_result(caller, result)
82                },
83                async {
84                    let mut env =
85                        self.fork(|handle| handle.spawn(caller, TaskDescription::Require(1)));
86                    let result = b(&mut env).await;
87                    env.log_result(caller, result)
88                },
89            )
90            .await?;
91            Ok(())
92        }
93    }
94
95    #[track_caller]
96    pub fn join<A, B>(
97        &mut self,
98        a: impl AsyncFnOnce(&mut Self) -> A,
99        b: impl AsyncFnOnce(&mut Self) -> B,
100    ) -> impl Future<Output = (A, B)>
101    where
102        A: erased_serde::Serialize,
103        B: erased_serde::Serialize,
104    {
105        let caller = Location::caller();
106        futures::future::join(
107            async {
108                let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Join(0)));
109                let result = a(&mut env).await;
110                env.log_result(caller, result)
111            },
112            async {
113                let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Join(1)));
114                let result = b(&mut env).await;
115                env.log_result(caller, result)
116            },
117        )
118    }
119
120    #[track_caller]
121    pub fn either(
122        &mut self,
123        a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
124        b: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
125    ) -> impl Future<Output = Errors<bool>> {
126        let caller = Location::caller();
127
128        async move {
129            let a = pin!(async {
130                let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Any(0)));
131                let result = a(&mut env).await;
132                env.log_result(caller, result)
133            });
134
135            let b = pin!(async {
136                let mut env = self.fork(|handle| handle.spawn(caller, TaskDescription::Any(1)));
137                let result = b(&mut env).await;
138                env.log_result(caller, result)
139            });
140
141            match futures::future::select(a, b).await {
142                Either::Left((Ok(true), _)) | Either::Right((Ok(true), _)) => Ok(true),
143                Either::Left((Err(reported), _)) | Either::Right((Err(reported), _)) => {
144                    Err(reported)
145                }
146                Either::Left((Ok(false), f)) => f.await,
147                Either::Right((Ok(false), f)) => f.await,
148            }
149        }
150    }
151
152    /// Returns true if any of the items satisfies the predicate.
153    /// Returns false if not.
154    /// Stops executing as soon as either an error or a true result is found.
155    #[track_caller]
156    pub fn for_all<T>(
157        &mut self,
158        items: impl IntoIterator<Item = T>,
159        test_fn: impl AsyncFn(&mut Env<'db>, T) -> Errors<bool>,
160    ) -> impl Future<Output = Errors<bool>> {
161        let compiler_location = Location::caller();
162
163        async move {
164            let this = &*self;
165            let test_fn = &test_fn;
166            let unordered = FuturesUnordered::new();
167            for (item, index) in items.into_iter().zip(0..) {
168                unordered.push(async move {
169                    let mut env = this.fork(|handle| {
170                        handle.spawn(compiler_location, TaskDescription::All(index))
171                    });
172                    let result = test_fn(&mut env, item).await;
173                    env.log_result(compiler_location, result)
174                });
175            }
176            let mut unordered = pin!(unordered);
177            while let Some(r) = unordered.next().await {
178                match r {
179                    Ok(true) => {}
180                    Ok(false) => return Ok(false),
181                    Err(reported) => return Err(reported),
182                }
183            }
184            Ok(true)
185        }
186    }
187
188    /// Returns true if any of the items satisfies the predicate.
189    /// Returns false if not.
190    /// Stops executing as soon as either an error or a true result is d.
191    #[track_caller]
192    pub fn exists<T>(
193        &mut self,
194        items: impl IntoIterator<Item = T>,
195        test_fn: impl AsyncFn(&mut Env<'db>, T) -> Errors<bool>,
196    ) -> impl Future<Output = Errors<bool>> {
197        let compiler_location = Location::caller();
198
199        async move {
200            let this = &*self;
201            let test_fn = &test_fn;
202            let unordered = FuturesUnordered::new();
203            for (item, index) in items.into_iter().zip(0..) {
204                unordered.push(async move {
205                    let mut env = this.fork(|handle| {
206                        handle.spawn(compiler_location, TaskDescription::Any(index))
207                    });
208                    let result = test_fn(&mut env, item).await;
209                    env.log_result(compiler_location, result)
210                });
211            }
212            let mut unordered = pin!(unordered);
213            while let Some(r) = unordered.next().await {
214                match r {
215                    Ok(true) => return Ok(true),
216                    Ok(false) => {}
217                    Err(reported) => return Err(reported),
218                }
219            }
220            Ok(false)
221        }
222    }
223
224    /// True if both `a` and `b` are true. Stops as soon as one is found to be false.
225    #[track_caller]
226    pub fn both(
227        &mut self,
228        a: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
229        b: impl AsyncFnOnce(&mut Env<'db>) -> Errors<bool>,
230    ) -> impl Future<Output = Errors<bool>> {
231        let compiler_location = Location::caller();
232
233        async move {
234            let a = async {
235                let mut env =
236                    self.fork(|handle| handle.spawn(compiler_location, TaskDescription::All(0)));
237                let result = a(&mut env).await;
238                env.log_result(compiler_location, result)
239            };
240
241            let b = async {
242                let mut env =
243                    self.fork(|handle| handle.spawn(compiler_location, TaskDescription::All(1)));
244                let result = b(&mut env).await;
245                env.log_result(compiler_location, result)
246            };
247
248            match futures::future::select(pin!(a), pin!(b)).await {
249                Either::Left((Ok(false), _)) | Either::Right((Ok(false), _)) => Ok(false),
250                Either::Left((Err(reported), _)) | Either::Right((Err(reported), _)) => {
251                    Err(reported)
252                }
253                Either::Left((Ok(true), f)) => f.await,
254                Either::Right((Ok(true), f)) => f.await,
255            }
256        }
257    }
258
259    /// Returns an iterator over the bounds on an inference variable
260    /// that appears under `perm`, yielding terms:
261    ///
262    /// * If this is a permission inference variable, the result are series of permission terms.
263    ///   These are directly converted from the [`RedPerm`](crate::check::red::RedPerm) bounds you get if you call [`Self::red_perm_bounds`].
264    /// * If this is a type inference variable, the result are series of type terms.
265    ///   They do not include the permission inference variable.
266    pub fn term_bounds(
267        &self,
268        perm: SymPerm<'db>,
269        infer: InferVarIndex,
270    ) -> SymGenericTermBoundIterator<'db> {
271        SymGenericTermBoundIterator::new(self, perm, infer)
272    }
273
274    /// Returns an iterator over the red perm bounds on a permission inference variable.
275    pub fn red_perm_bounds(&self, infer: InferVarIndex) -> RedPermBoundIterator<'db> {
276        RedPermBoundIterator::new(self, infer)
277    }
278
279    /// Returns an iterator over the red ty bounds on a type inference variable.
280    /// Note that each type inference variable has an associated permission
281    /// inference variable and that this permission is not reflected in the red ty
282    /// bound. Use [`Self::term_bounds`] to get back the complete inferred type.
283    #[expect(dead_code)]
284    pub fn red_ty_bounds(&self, infer: InferVarIndex) -> RedTyBoundIterator<'db> {
285        RedTyBoundIterator::new(self, infer)
286    }
287
288    #[track_caller]
289    pub fn loop_on_inference_var<T>(
290        &self,
291        infer: InferVarIndex,
292        op: impl FnMut(&InferenceVarData<'db>) -> Option<T>,
293    ) -> impl Future<Output = Option<T>>
294    where
295        T: Serialize,
296    {
297        let compiler_location = Location::caller();
298        self.runtime
299            .loop_on_inference_var(infer, compiler_location, &self.log, op)
300    }
301
302    /// Invoke `op` for each new lower (or upper, depending on direction) bound on `?X`.
303    pub async fn for_each_bound(
304        &mut self,
305        direction: Direction,
306        infer: InferVarIndex,
307        mut op: impl AsyncFnMut(&mut Env<'db>, &RedTy<'db>, ArcOrElse<'db>) -> Errors<()>,
308    ) -> Errors<()> {
309        let mut previous_red_ty = None;
310        loop {
311            let new_pair = self
312                .loop_on_inference_var(infer, |data| {
313                    let (red_ty, or_else) = data.red_ty_bound(direction)?;
314                    if let Some(previous_ty) = &previous_red_ty
315                        && red_ty == *previous_ty
316                    {
317                        return None;
318                    }
319                    Some((red_ty, or_else))
320                })
321                .await;
322
323            match new_pair {
324                None => return Ok(()),
325                Some((lower_red_ty, or_else)) => {
326                    self.log("for_each_bound", &[&infer, &lower_red_ty]);
327                    previous_red_ty = Some(lower_red_ty);
328                    op(self, previous_red_ty.as_ref().unwrap(), or_else).await?;
329                }
330            }
331        }
332    }
333}
334
335pub struct RequireAll<'env, 'db> {
336    env: &'env Env<'db>,
337    required: Vec<LocalBoxFuture<'env, Errors<()>>>,
338}
339
340impl<'env, 'db> RequireAll<'env, 'db> {
341    #[track_caller]
342    pub fn require(mut self, op: impl AsyncFnOnce(&mut Env<'db>) -> Errors<()> + 'env) -> Self {
343        let index = self.required.len();
344        let compiler_location = Location::caller();
345        let mut env = self
346            .env
347            .fork(|log| log.spawn(compiler_location, TaskDescription::All(index)));
348        let future = async move { op(&mut env).await };
349        self.required.push(Box::pin(future));
350        self
351    }
352
353    pub async fn finish(self) -> Errors<()> {
354        futures::future::try_join_all(self.required).await?;
355        Ok(())
356    }
357}