1use std::fmt::Debug;
2
3use dada_util::Never;
4use salsa::Update;
5use serde::Serialize;
6
7use crate::{
8 ir::subst::{Subst, SubstitutionFns},
9 ir::types::{HasKind, SymGenericKind},
10 ir::variables::SymVariable,
11};
12
13#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Update, Debug, Serialize)]
15pub struct Binder<'db, T: BoundTerm<'db>> {
16 pub variables: Vec<SymVariable<'db>>,
17 pub bound_value: T,
18}
19
20impl<'db, T: BoundTerm<'db>> Binder<'db, T> {
21 pub fn len(&self) -> usize {
22 self.variables.len()
23 }
24
25 pub fn is_empty(&self) -> bool {
26 self.variables.is_empty()
27 }
28
29 pub fn kind(&self, db: &'db dyn crate::Db, index: usize) -> SymGenericKind {
30 self.variables[index].kind(db)
31 }
32
33 pub fn open(
42 &self,
43 db: &'db dyn crate::Db,
44 mut func: impl FnMut(usize) -> T::GenericTerm,
45 ) -> T::Output
46 where
47 T: Subst<'db>,
48 {
49 let mut cache = vec![None; self.len()];
50
51 self.bound_value.subst_with(
52 db,
53 &mut Default::default(),
54 &mut SubstitutionFns {
55 free_var: &mut |var| {
56 self.variables
57 .iter()
58 .position(|v| *v == var)
59 .map(|index| *cache[index].get_or_insert_with(|| func(index)))
60 },
61 infer_var: &mut |_| None,
62 },
63 )
64 }
65
66 pub fn substitute(
72 &self,
73 db: &'db dyn crate::Db,
74 substitution: &[impl Into<T::GenericTerm> + Copy],
75 ) -> T::Output {
76 assert_eq!(self.len(), substitution.len());
77 self.open(db, |index| {
78 let term = substitution[index].into();
79 assert!(term.has_kind(db, self.kind(db, index)));
80 term
81 })
82 }
83
84 pub fn map<U>(self, _db: &'db dyn crate::Db, op: impl FnOnce(T) -> U) -> Binder<'db, U>
95 where
96 U: BoundTerm<'db>,
97 {
98 Binder {
99 variables: self.variables,
100 bound_value: op(self.bound_value),
101 }
102 }
103
104 pub fn map_ref<U>(&self, _db: &'db dyn crate::Db, op: impl FnOnce(&T) -> U) -> Binder<'db, U>
115 where
116 U: BoundTerm<'db>,
117 {
118 Binder {
119 variables: self.variables.clone(),
120 bound_value: op(&self.bound_value),
121 }
122 }
123}
124
125impl<'db, T> std::ops::Index<usize> for Binder<'db, T>
126where
127 T: BoundTerm<'db>,
128{
129 type Output = SymVariable<'db>;
130
131 fn index(&self, index: usize) -> &Self::Output {
132 &self.variables[index]
133 }
134}
135
136pub trait BoundTerm<'db>: Update + Subst<'db, Output = Self> + Sized {
138 const BINDER_LEVELS: usize;
139 type BoundTerm: BoundTerm<'db, LeafTerm = Self::LeafTerm>;
140 type LeafTerm: Subst<'db, Output = Self::LeafTerm>;
141
142 fn bind(
143 db: &'db dyn crate::Db,
144 symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
145 leaf_value: Self::LeafTerm,
146 ) -> Self;
147
148 fn as_binder(&self) -> Result<&Binder<'db, Self::BoundTerm>, &Self::LeafTerm>;
149}
150
151pub trait LeafBoundTerm<'db>: Update + Subst<'db, Output = Self> {}
152
153impl<'db, T> BoundTerm<'db> for T
154where
155 T: LeafBoundTerm<'db>,
156{
157 const BINDER_LEVELS: usize = 0;
158 type BoundTerm = NeverBinder<T>;
159 type LeafTerm = T;
160
161 fn bind(
162 _db: &'db dyn crate::Db,
163 symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
164 value: T,
165 ) -> Self {
166 assert!(
167 symbols_to_bind.next().is_none(),
168 "incorrect number of binding levels in iterator"
169 );
170 value
171 }
172
173 fn as_binder(&self) -> Result<&Binder<'db, NeverBinder<Self>>, &Self> {
174 Err(self)
175 }
176}
177
178impl LeafBoundTerm<'_> for Never {}
179
180impl<'db, T> BoundTerm<'db> for Binder<'db, T>
181where
182 T: BoundTerm<'db>,
183{
184 const BINDER_LEVELS: usize = T::BINDER_LEVELS + 1;
185 type BoundTerm = T;
186 type LeafTerm = T::LeafTerm;
187
188 fn bind(
189 db: &'db dyn crate::Db,
190 symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
191 leaf_value: T::LeafTerm,
192 ) -> Self {
193 let variables = symbols_to_bind.next().unwrap();
197
198 let u = T::bind(db, symbols_to_bind, leaf_value);
201 Binder {
202 variables,
203 bound_value: u,
204 }
205 }
206
207 fn as_binder(&self) -> Result<&Binder<'db, T>, &Self::LeafTerm> {
208 Ok(self)
209 }
210}
211
212#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
213pub struct NeverBinder<T> {
214 _data: Never,
215 _value: T,
216}
217
218unsafe impl<T> Update for NeverBinder<T> {
219 unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
220 unreachable!()
221 }
222}
223
224impl<'db, T: Subst<'db, Output = T>> BoundTerm<'db> for NeverBinder<T> {
225 const BINDER_LEVELS: usize = 0;
226
227 type BoundTerm = NeverBinder<T>;
228
229 type LeafTerm = T;
230
231 fn bind(
232 _db: &'db dyn crate::Db,
233 _symbols_to_bind: &mut dyn Iterator<Item = Vec<SymVariable<'db>>>,
234 _leaf_value: Self::LeafTerm,
235 ) -> Self {
236 unreachable!()
237 }
238
239 fn as_binder(&self) -> Result<&Binder<'db, Self::BoundTerm>, &Self::LeafTerm> {
240 unreachable!()
241 }
242}