// dada_ir_sym/ir/binder.rs

1use std::fmt::Debug;
2
3use dada_util::Never;
4use salsa::Update;
5use serde::Serialize;
6
7use crate::{
8    ir::subst::{Subst, SubstitutionFns},
9    ir::types::{HasKind, SymGenericKind},
10    ir::variables::SymVariable,
11};
12
13/// Indicates a binder for generic variables
14#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Update, Debug, Serialize)]
15pub struct Binder<'db, T: BoundTerm<'db>> {
16    pub variables: Vec<SymVariable<'db>>,
17    pub bound_value: T,
18}
19
20impl<'db, T: BoundTerm<'db>> Binder<'db, T> {
21    pub fn len(&self) -> usize {
22        self.variables.len()
23    }
24
25    pub fn is_empty(&self) -> bool {
26        self.variables.is_empty()
27    }
28
29    pub fn kind(&self, db: &'db dyn crate::Db, index: usize) -> SymGenericKind {
30        self.variables[index].kind(db)
31    }
32
33    /// Generic way to "open" a binder, giving a function that computes the replacement
34    /// value for each bound variable. You may preference [`Self::substitute`][] for the
35    /// most common cases.
36    ///
37    /// # Parameters
38    ///
39    /// * `db`, the database
40    /// * `func`, compute the replacement for bound variable at the given index
41    pub fn open(
42        &self,
43        db: &'db dyn crate::Db,
44        mut func: impl FnMut(usize) -> T::GenericTerm,
45    ) -> T::Output
46    where
47        T: Subst<'db>,
48    {
49        let mut cache = vec![None; self.len()];
50
51        self.bound_value.subst_with(
52            db,
53            &mut Default::default(),
54            &mut SubstitutionFns {
55                free_var: &mut |var| {
56                    self.variables
57                        .iter()
58                        .position(|v| *v == var)
59                        .map(|index| *cache[index].get_or_insert_with(|| func(index)))
60                },
61                infer_var: &mut |_| None,
62            },
63        )
64    }
65
66    /// Open the binder by replacing each variable with the corresponding term from `substitution`.
67    ///
68    /// # Panics
69    ///
70    /// If `substitution` does not have the correct length or there is a kind-mismatch.
71    pub fn substitute(
72        &self,
73        db: &'db dyn crate::Db,
74        substitution: &[impl Into<T::GenericTerm> + Copy],
75    ) -> T::Output {
76        assert_eq!(self.len(), substitution.len());
77        self.open(db, |index| {
78            let term = substitution[index].into();
79            assert!(term.has_kind(db, self.kind(db, index)));
80            term
81        })
82    }
83
84    /// Maps the bound contents to something else
85    /// using the contents of argument term `arg`.
86    ///
87    /// `arg` will automatically have any bound variables
88    /// shifted by 1 to account for having been inserted
89    /// into a new binder.
90    ///
91    /// If no arg is needed just supply `()`.
92    ///
93    /// NB. The argument is a `fn` to prevent accidentally leaking context.
94    pub fn map<U>(self, _db: &'db dyn crate::Db, op: impl FnOnce(T) -> U) -> Binder<'db, U>
95    where
96        U: BoundTerm<'db>,
97    {
98        Binder {
99            variables: self.variables,
100            bound_value: op(self.bound_value),
101        }
102    }
103
104    /// Maps the bound contents to something else
105    /// using the contents of argument term `arg`.
106    ///
107    /// `arg` will automatically have any bound variables
108    /// shifted by 1 to account for having been inserted
109    /// into a new binder.
110    ///
111    /// If no arg is needed just supply `()`.
112    ///
113    /// NB. The argument is a `fn` to prevent accidentally leaking context.
114    pub fn map_ref<U>(&self, _db: &'db dyn crate::Db, op: impl FnOnce(&T) -> U) -> Binder<'db, U>
115    where
116        U: BoundTerm<'db>,
117    {
118        Binder {
119            variables: self.variables.clone(),
120            bound_value: op(&self.bound_value),
121        }
122    }
123}
124
125impl<'db, T> std::ops::Index<usize> for Binder<'db, T>
126where
127    T: BoundTerm<'db>,
128{
129    type Output = SymVariable<'db>;
130
131    fn index(&self, index: usize) -> &Self::Output {
132        &self.variables[index]
133    }
134}
135
136/// A value that can appear in a binder
137pub trait BoundTerm<'db>: Update + Subst<'db, Output = Self> + Sized {
138    const BINDER_LEVELS: usize;
139    type BoundTerm: BoundTerm<'db, LeafTerm = Self::LeafTerm>;
140    type LeafTerm: Subst<'db, Output = Self::LeafTerm>;
141
142    fn bind(
143        db: &'db dyn crate::Db,
144        symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
145        leaf_value: Self::LeafTerm,
146    ) -> Self;
147
148    fn as_binder(&self) -> Result<&Binder<'db, Self::BoundTerm>, &Self::LeafTerm>;
149}
150
151pub trait LeafBoundTerm<'db>: Update + Subst<'db, Output = Self> {}
152
153impl<'db, T> BoundTerm<'db> for T
154where
155    T: LeafBoundTerm<'db>,
156{
157    const BINDER_LEVELS: usize = 0;
158    type BoundTerm = NeverBinder<T>;
159    type LeafTerm = T;
160
161    fn bind(
162        _db: &'db dyn crate::Db,
163        symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
164        value: T,
165    ) -> Self {
166        assert!(
167            symbols_to_bind.next().is_none(),
168            "incorrect number of binding levels in iterator"
169        );
170        value
171    }
172
173    fn as_binder(&self) -> Result<&Binder<'db, NeverBinder<Self>>, &Self> {
174        Err(self)
175    }
176}
177
178impl LeafBoundTerm<'_> for Never {}
179
180impl<'db, T> BoundTerm<'db> for Binder<'db, T>
181where
182    T: BoundTerm<'db>,
183{
184    const BINDER_LEVELS: usize = T::BINDER_LEVELS + 1;
185    type BoundTerm = T;
186    type LeafTerm = T::LeafTerm;
187
188    fn bind(
189        db: &'db dyn crate::Db,
190        symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
191        leaf_value: T::LeafTerm,
192    ) -> Self {
193        // Extract next level of bound symbols for use in this binder;
194        // if this unwrap fails, user gave wrong number of `Binder<_>` types
195        // for the scope.
196        let variables = symbols_to_bind.next().unwrap();
197
198        // Introduce whatever binders are needed to go from the innermost
199        // value type `T` to `U`.
200        let u = T::bind(db, symbols_to_bind, leaf_value);
201        Binder {
202            variables,
203            bound_value: u,
204        }
205    }
206
207    fn as_binder(&self) -> Result<&Binder<'db, T>, &Self::LeafTerm> {
208        Ok(self)
209    }
210}
211
212#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
213pub struct NeverBinder<T> {
214    _data: Never,
215    _value: T,
216}
217
218unsafe impl<T> Update for NeverBinder<T> {
219    unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
220        unreachable!()
221    }
222}
223
224impl<'db, T: Subst<'db, Output = T>> BoundTerm<'db> for NeverBinder<T> {
225    const BINDER_LEVELS: usize = 0;
226
227    type BoundTerm = NeverBinder<T>;
228
229    type LeafTerm = T;
230
231    fn bind(
232        _db: &'db dyn crate::Db,
233        _symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
234        _leaf_value: Self::LeafTerm,
235    ) -> Self {
236        unreachable!()
237    }
238
239    fn as_binder(&self) -> Result<&Binder<'db, Self::BoundTerm>, &Self::LeafTerm> {
240        unreachable!()
241    }
242}