Skip to main content

dada_ir_sym/ir/
subst.rs

1use std::fmt::Debug;
2
3use dada_ir_ast::{diagnostic::Reported, span::Span};
4use dada_util::{Map, Never};
5
6use crate::{
7    ir::binder::{Binder, BoundTerm, NeverBinder},
8    ir::functions::SymInputOutput,
9    ir::types::{
10        AssertKind, HasKind, SymGenericKind, SymGenericTerm, SymPerm, SymPermKind, SymPlace,
11        SymPlaceKind, SymTy, SymTyKind, SymTyName,
12    },
13    ir::variables::{FromVar, SymVariable},
14};
15
16use super::{
17    classes::SymField,
18    functions::SymFunction,
19    generics::{SymWhereClause, SymWhereClauseKind},
20    indices::InferVarIndex,
21};
22
23pub struct SubstitutionFns<'s, 'db, Term> {
24    /// Invoked for free variables.
25    ///
26    /// If this returns None, no substitution is performed.
27    pub free_var: &'s mut dyn FnMut(SymVariable<'db>) -> Option<Term>,
28
29    /// Invoked for inference variables.
30    ///
31    /// If this returns None, no substitution is performed.
32    pub infer_var: &'s mut dyn FnMut(InferVarIndex) -> Option<Term>,
33}
34
35pub fn default_free_var<Term>(_: SymVariable<'_>) -> Option<Term> {
36    None
37}
38
39/// A type implemented by terms that can be substituted.
40pub trait Subst<'db>: SubstWith<'db, Self::GenericTerm> + Debug {
41    /// The notion of generic term appropriate for this type.
42    /// When we substitute variables, this is the type of value that we replace them with.
43    type GenericTerm: Copy + HasKind<'db> + Debug + FromVar<'db>;
44
45    /// Returns a version of `self` where universal free variables
46    /// have been replaced by the corresponding entry in `terms`.
47    /// If a variable is not present in `terms` it is not substituted.
48    fn subst_vars(
49        &self,
50        db: &'db dyn crate::Db,
51        map: &Map<SymVariable<'db>, Self::GenericTerm>,
52    ) -> Self::Output {
53        debug_assert!(
54            map.iter()
55                .all(|(&var, term)| term.has_kind(db, var.kind(db)))
56        );
57
58        self.subst_with(
59            db,
60            &mut Default::default(),
61            &mut SubstitutionFns {
62                free_var: &mut |var| map.get(&var).copied(),
63                infer_var: &mut |_| None,
64            },
65        )
66    }
67
68    /// Replace the variable `var` with `term`.
69    fn subst_var(
70        &self,
71        db: &'db dyn crate::Db,
72        var: SymVariable<'db>,
73        term: Self::GenericTerm,
74    ) -> Self::Output {
75        debug_assert!(term.has_kind(db, var.kind(db)));
76
77        self.subst_with(
78            db,
79            &mut Default::default(),
80            &mut SubstitutionFns {
81                free_var: &mut |v| if v == var { Some(term) } else { None },
82                infer_var: &mut |_| None,
83            },
84        )
85    }
86
87    /// Replace all inference variables with whatever is returned by `op`;
88    /// if `op` returns None, the inference variable is left unchanged.
89    fn resolve_infer_var(
90        &self,
91        db: &'db dyn crate::Db,
92        bound_vars: &mut Vec<SymVariable<'db>>,
93        mut op: impl FnMut(InferVarIndex) -> Option<Self::GenericTerm>,
94    ) -> Self::Output {
95        self.subst_with(
96            db,
97            bound_vars,
98            &mut SubstitutionFns {
99                free_var: &mut |_| None,
100                infer_var: &mut op,
101            },
102        )
103    }
104}
105
106/// Core substitution operation: produce a version of this type
107/// with variables replaced with instances of `Term`.
108///
109/// Most types implement this for only a single `Term`, but not all
110/// (see the macro [`identity_subst`][]).
111pub trait SubstWith<'db, Term> {
112    /// The type of the resulting term; typically `Self` but not always.
113    type Output;
114
115    /// Reproduce `self` with no edits.
116    fn identity(&self) -> Self::Output;
117
118    /// Replace `self` applying the changes from `subst_fns`.
119    ///
120    /// # Parameters
121    ///
122    /// * `db`, the database
123    /// * `start_binder`, the index of the binder we started from.
124    ///   This always begins as `SymBinderIndex::INNERMOST`
125    ///   but gets incremented as we traverse binders.
126    /// * `subst_fns`, a struct containing callbacks for substitution
127    fn subst_with<'subst>(
128        &'subst self,
129        db: &'db dyn crate::Db,
130        bound_vars: &mut Vec<SymVariable<'db>>,
131        subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
132    ) -> Self::Output;
133}
134
135impl<'db> Subst<'db> for Never {
136    type GenericTerm = SymGenericTerm<'db>;
137}
138
139impl<'db, Term> SubstWith<'db, Term> for Never {
140    type Output = Never;
141
142    fn identity(&self) -> Self::Output {
143        unreachable!()
144    }
145
146    fn subst_with<'subst>(
147        &'subst self,
148        _db: &'db dyn crate::Db,
149        _bound_vars: &mut Vec<SymVariable<'db>>,
150        _subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
151    ) -> Self::Output {
152        unreachable!()
153    }
154}
155
156impl<'db, T> Subst<'db> for &T
157where
158    T: Subst<'db>,
159{
160    type GenericTerm = T::GenericTerm;
161}
162
163impl<'db, T, Term> SubstWith<'db, Term> for &T
164where
165    T: SubstWith<'db, Term>,
166{
167    type Output = T::Output;
168
169    fn subst_with<'subst>(
170        &'subst self,
171        db: &'db dyn crate::Db,
172        bound_vars: &mut Vec<SymVariable<'db>>,
173        subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
174    ) -> Self::Output {
175        T::subst_with(self, db, bound_vars, subst_fns)
176    }
177
178    fn identity(&self) -> Self::Output {
179        T::identity(self)
180    }
181}
182
183impl<'db> Subst<'db> for SymGenericTerm<'db> {
184    type GenericTerm = SymGenericTerm<'db>;
185}
186
187impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymGenericTerm<'db> {
188    type Output = Self;
189
190    fn subst_with<'subst>(
191        &'subst self,
192        db: &'db dyn crate::Db,
193        bound_vars: &mut Vec<SymVariable<'db>>,
194        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
195    ) -> Self::Output {
196        match self {
197            SymGenericTerm::Type(ty) => {
198                SymGenericTerm::Type(ty.subst_with(db, bound_vars, subst_fns))
199            }
200            SymGenericTerm::Perm(perm) => {
201                SymGenericTerm::Perm(perm.subst_with(db, bound_vars, subst_fns))
202            }
203            SymGenericTerm::Place(place) => {
204                SymGenericTerm::Place(place.subst_with(db, bound_vars, subst_fns))
205            }
206            SymGenericTerm::Error(e) => {
207                SymGenericTerm::Error(e.subst_with(db, bound_vars, subst_fns))
208            }
209        }
210    }
211
212    fn identity(&self) -> Self::Output {
213        *self
214    }
215}
216
217impl<'db> Subst<'db> for SymTy<'db> {
218    type GenericTerm = SymGenericTerm<'db>;
219}
220
221impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymTy<'db> {
222    type Output = Self;
223
224    fn subst_with<'subst>(
225        &'subst self,
226        db: &'db dyn crate::Db,
227        bound_vars: &mut Vec<SymVariable<'db>>,
228        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
229    ) -> Self::Output {
230        match self.kind(db) {
231            // Variables
232            SymTyKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
233            SymTyKind::Infer(v) => {
234                if let Some(term) = (subst_fns.infer_var)(*v) {
235                    term.assert_type(db)
236                } else {
237                    self.identity()
238                }
239            }
240
241            // Structucal cases
242            SymTyKind::Perm(sym_perm, sym_ty) => SymTy::new(
243                db,
244                SymTyKind::Perm(
245                    sym_perm.subst_with(db, bound_vars, subst_fns),
246                    sym_ty.subst_with(db, bound_vars, subst_fns),
247                ),
248            ),
249            SymTyKind::Named(sym_ty_name, vec) => SymTy::new(
250                db,
251                SymTyKind::Named(
252                    sym_ty_name.subst_with(db, bound_vars, subst_fns),
253                    vec.iter()
254                        .map(|g| g.subst_with(db, bound_vars, subst_fns))
255                        .collect(),
256                ),
257            ),
258            SymTyKind::Never => self.identity(),
259            SymTyKind::Error(_) => self.identity(),
260        }
261    }
262
263    fn identity(&self) -> Self::Output {
264        *self
265    }
266}
267
268impl<'db> Subst<'db> for SymPerm<'db> {
269    type GenericTerm = SymGenericTerm<'db>;
270}
271
272impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPerm<'db> {
273    type Output = Self;
274
275    fn identity(&self) -> Self::Output {
276        *self
277    }
278
279    fn subst_with<'subst>(
280        &self,
281        db: &'db dyn crate::Db,
282        bound_vars: &mut Vec<SymVariable<'db>>,
283        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
284    ) -> Self::Output {
285        match self.kind(db) {
286            SymPermKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
287            SymPermKind::Infer(v) => {
288                if let Some(term) = (subst_fns.infer_var)(*v) {
289                    term.assert_perm(db)
290                } else {
291                    self.identity()
292                }
293            }
294            SymPermKind::Referenced(vec) => SymPerm::new(
295                db,
296                SymPermKind::Referenced(
297                    vec.iter()
298                        .map(|g| g.subst_with(db, bound_vars, subst_fns))
299                        .collect(),
300                ),
301            ),
302            SymPermKind::Mutable(vec) => SymPerm::new(
303                db,
304                SymPermKind::Mutable(
305                    vec.iter()
306                        .map(|g| g.subst_with(db, bound_vars, subst_fns))
307                        .collect(),
308                ),
309            ),
310            SymPermKind::Error(reported) => SymPerm::new(
311                db,
312                SymPermKind::Error(reported.subst_with(db, bound_vars, subst_fns)),
313            ),
314            SymPermKind::My => self.identity(),
315            SymPermKind::Our => self.identity(),
316            SymPermKind::Apply(left, right) => SymPerm::new(
317                db,
318                SymPermKind::Apply(
319                    left.subst_with(db, bound_vars, subst_fns),
320                    right.subst_with(db, bound_vars, subst_fns),
321                ),
322            ),
323            SymPermKind::Or(left, right) => SymPerm::new(
324                db,
325                SymPermKind::Or(
326                    left.subst_with(db, bound_vars, subst_fns),
327                    right.subst_with(db, bound_vars, subst_fns),
328                ),
329            ),
330        }
331    }
332}
333
334impl<'db> Subst<'db> for SymPlace<'db> {
335    type GenericTerm = SymGenericTerm<'db>;
336}
337
338impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPlace<'db> {
339    type Output = Self;
340
341    fn identity(&self) -> Self::Output {
342        *self
343    }
344
345    fn subst_with<'subst>(
346        &'subst self,
347        db: &'db dyn crate::Db,
348        bound_vars: &mut Vec<SymVariable<'db>>,
349        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
350    ) -> Self::Output {
351        match self.kind(db) {
352            SymPlaceKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
353            SymPlaceKind::Field(sym_place, identifier) => SymPlace::new(
354                db,
355                SymPlaceKind::Field(sym_place.subst_with(db, bound_vars, subst_fns), *identifier),
356            ),
357            SymPlaceKind::Index(sym_place) => SymPlace::new(
358                db,
359                SymPlaceKind::Index(sym_place.subst_with(db, bound_vars, subst_fns)),
360            ),
361            SymPlaceKind::Error(reported) => SymPlace::new(
362                db,
363                SymPlaceKind::Error(reported.subst_with(db, bound_vars, subst_fns)),
364            ),
365            SymPlaceKind::Erased => SymPlace::new(db, SymPlaceKind::Erased),
366        }
367    }
368}
369
370impl<'db> Subst<'db> for SymWhereClause<'db> {
371    type GenericTerm = SymGenericTerm<'db>;
372}
373
374impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymWhereClause<'db> {
375    type Output = SymWhereClause<'db>;
376
377    fn identity(&self) -> Self::Output {
378        *self
379    }
380
381    fn subst_with<'subst>(
382        &'subst self,
383        db: &'db dyn crate::Db,
384        bound_vars: &mut Vec<SymVariable<'db>>,
385        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
386    ) -> Self::Output {
387        SymWhereClause::new(
388            db,
389            self.subject(db).subst_with(db, bound_vars, subst_fns),
390            self.kind(db).subst_with(db, bound_vars, subst_fns),
391        )
392    }
393}
394
395impl<'db, T: BoundTerm<'db>> Subst<'db> for Binder<'db, T>
396where
397    T::Output: BoundTerm<'db>,
398{
399    type GenericTerm = T::GenericTerm;
400}
401
402impl<'db, T: BoundTerm<'db>> SubstWith<'db, T::GenericTerm> for Binder<'db, T> {
403    type Output = Binder<'db, T>;
404
405    fn identity(&self) -> Self::Output {
406        Binder {
407            variables: self.variables.clone(),
408            bound_value: T::identity(&self.bound_value),
409        }
410    }
411
412    fn subst_with<'subst>(
413        &'subst self,
414        db: &'db dyn crate::Db,
415        bound_vars: &mut Vec<SymVariable<'db>>,
416        subst_fns: &mut SubstitutionFns<'_, 'db, T::GenericTerm>,
417    ) -> Self::Output {
418        let len = bound_vars.len();
419        bound_vars.extend_from_slice(&self.variables);
420        let bound_value = self.bound_value.subst_with(db, bound_vars, subst_fns);
421        bound_vars.truncate(len);
422
423        Binder {
424            variables: self.variables.clone(),
425            bound_value,
426        }
427    }
428}
429
430impl<'db, T> Subst<'db> for NeverBinder<T>
431where
432    T: Debug,
433{
434    type GenericTerm = SymGenericTerm<'db>;
435}
436
437impl<'db, T, Term> SubstWith<'db, Term> for NeverBinder<T> {
438    type Output = Self;
439
440    fn identity(&self) -> Self::Output {
441        unreachable!()
442    }
443
444    fn subst_with<'subst>(
445        &'subst self,
446        _db: &'db dyn crate::Db,
447        _bound_vars: &mut Vec<SymVariable<'db>>,
448        _subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
449    ) -> Self::Output {
450        unreachable!()
451    }
452}
453
454impl<'db> Subst<'db> for SymInputOutput<'db> {
455    type GenericTerm = SymGenericTerm<'db>;
456}
457
458impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymInputOutput<'db> {
459    type Output = Self;
460
461    fn identity(&self) -> Self::Output {
462        self.clone()
463    }
464
465    fn subst_with<'subst>(
466        &'subst self,
467        db: &'db dyn crate::Db,
468        bound_vars: &mut Vec<SymVariable<'db>>,
469        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
470    ) -> Self::Output {
471        SymInputOutput {
472            input_tys: self.input_tys.subst_with(db, bound_vars, subst_fns),
473            output_ty: self.output_ty.subst_with(db, bound_vars, subst_fns),
474            where_clauses: self.where_clauses.subst_with(db, bound_vars, subst_fns),
475        }
476    }
477}
478
479impl<'db, T> Subst<'db> for Vec<T>
480where
481    T: Subst<'db>,
482{
483    type GenericTerm = T::GenericTerm;
484}
485
486impl<'db, T: Subst<'db>> SubstWith<'db, T::GenericTerm> for Vec<T> {
487    type Output = Vec<T::Output>;
488
489    fn identity(&self) -> Self::Output {
490        self.iter().map(T::identity).collect()
491    }
492
493    fn subst_with<'subst>(
494        &'subst self,
495        db: &'db dyn crate::Db,
496        bound_vars: &mut Vec<SymVariable<'db>>,
497        subst_fns: &mut SubstitutionFns<'_, 'db, T::GenericTerm>,
498    ) -> Self::Output {
499        self.iter()
500            .map(|t| t.subst_with(db, bound_vars, subst_fns))
501            .collect()
502    }
503}
504
505pub fn subst_var<'db, Output, Term>(
506    db: &'db dyn crate::Db,
507    bound_vars: &mut Vec<SymVariable<'db>>,
508    subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
509    var: SymVariable<'db>,
510) -> Output
511where
512    Term: AssertKind<'db, Output>,
513    Output: FromVar<'db>,
514{
515    let var_appears_free = !bound_vars.contains(&var);
516
517    if var_appears_free && let Some(term) = (subst_fns.free_var)(var) {
518        return term.assert_kind(db);
519    }
520
521    Output::var(db, var)
522}
523
/// For types that do not contain any potentially substitutable
/// content, we can use a simple impl that hands back a copy of `self`
/// unchanged. Note that while these types set [`Subst::GenericTerm`]
/// to `SymGenericTerm`, that is just for convenience -- they implement
/// [`SubstWith`][] for any type `Term`.
///
/// Every listed type must be `Copy`, since both generated methods return
/// `*self`. The list may optionally end with a trailing comma.
macro_rules! identity_subst {
    (for $l:lifetime { $($t:ty),* $(,)? }) => {
        $(
            impl<$l> Subst<$l> for $t {
                type GenericTerm = SymGenericTerm<$l>;
            }

            impl<$l, Term> SubstWith<$l, Term> for $t {
                type Output = Self;

                fn identity(&self) -> Self::Output {
                    *self
                }

                fn subst_with<'subst>(
                    &'subst self,
                    _db: &$l dyn crate::Db,
                    // Use the macro's lifetime parameter (not a hard-coded `'db`)
                    // so invocations with any lifetime name compile.
                    _bound_vars: &mut Vec<SymVariable<$l>>,
                    _subst_fns: &mut SubstitutionFns<'_, $l, Term>,
                ) -> Self::Output {
                    *self
                }
            }
        )*
    };
}
pub(crate) use identity_subst; // Now classic paths Just Work™
556
557identity_subst! {
558    for 'db {
559        (),
560        Reported,
561        SymGenericKind,
562        SymTyName<'db>,
563        Span<'db>,
564        SymFunction<'db>,
565        SymField<'db>,
566        SymWhereClauseKind,
567    }
568}
569
570impl<'db, Term, T> SubstWith<'db, Term> for Option<T>
571where
572    T: SubstWith<'db, Term>,
573{
574    type Output = Option<T::Output>;
575
576    fn identity(&self) -> Self::Output {
577        self.as_ref().map(|v| v.identity())
578    }
579
580    fn subst_with<'subst>(
581        &'subst self,
582        db: &'db dyn crate::Db,
583        bound_vars: &mut Vec<SymVariable<'db>>,
584        subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
585    ) -> Self::Output {
586        self.as_ref()
587            .map(|v| v.subst_with(db, bound_vars, subst_fns))
588    }
589}
590
591impl<'db, O, E, Term> SubstWith<'db, Term> for Result<O, E>
592where
593    O: SubstWith<'db, Term>,
594    E: SubstWith<'db, Term>,
595{
596    type Output = Result<O::Output, E::Output>;
597
598    fn identity(&self) -> Self::Output {
599        match self {
600            Ok(v) => Ok(v.identity()),
601            Err(e) => Err(e.identity()),
602        }
603    }
604
605    fn subst_with<'subst>(
606        &'subst self,
607        db: &'db dyn crate::Db,
608        bound_vars: &mut Vec<SymVariable<'db>>,
609        subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
610    ) -> Self::Output {
611        match self {
612            Ok(v) => Ok(v.subst_with(db, bound_vars, subst_fns)),
613            Err(e) => Err(e.subst_with(db, bound_vars, subst_fns)),
614        }
615    }
616}