Skip to main content

dada_ir_sym/check/
subst_impls.rs

1use dada_ir_ast::ast::PermissionOp;
2
3use crate::ir::{
4    exprs::{
5        SymBinaryOp, SymByteLiteral, SymExpr, SymExprKind, SymLiteral, SymMatchArm, SymPlaceExpr,
6        SymPlaceExprKind,
7    },
8    functions::SymFunctionSignature,
9    subst::{Subst, SubstWith, SubstitutionFns, identity_subst},
10    types::SymGenericTerm,
11    variables::SymVariable,
12};
13
14impl<'db> Subst<'db> for SymExpr<'db> {
15    type GenericTerm = SymGenericTerm<'db>;
16}
17
18impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymExpr<'db> {
19    type Output = SymExpr<'db>;
20
21    fn identity(&self) -> Self::Output {
22        *self
23    }
24
25    fn subst_with<'subst>(
26        &'subst self,
27        db: &'db dyn crate::Db,
28        bound_vars: &mut Vec<SymVariable<'db>>,
29        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
30    ) -> Self::Output {
31        let span = self.span(db);
32        let ty = self.ty(db).subst_with(db, bound_vars, subst_fns);
33        let kind = self.kind(db).subst_with(db, bound_vars, subst_fns);
34        SymExpr::new(db, span, ty, kind)
35    }
36}
37
38impl<'db> Subst<'db> for SymExprKind<'db> {
39    type GenericTerm = SymGenericTerm<'db>;
40}
41
42impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymExprKind<'db> {
43    type Output = SymExprKind<'db>;
44
45    fn identity(&self) -> Self::Output {
46        self.clone()
47    }
48
49    fn subst_with<'subst>(
50        &'subst self,
51        db: &'db dyn crate::Db,
52        bound_vars: &mut Vec<SymVariable<'db>>,
53        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
54    ) -> Self::Output {
55        match self {
56            SymExprKind::Semi(sym_expr, sym_expr1) => SymExprKind::Semi(
57                sym_expr.subst_with(db, bound_vars, subst_fns),
58                sym_expr1.subst_with(db, bound_vars, subst_fns),
59            ),
60            SymExprKind::Tuple(vec) => {
61                SymExprKind::Tuple(vec.subst_with(db, bound_vars, subst_fns))
62            }
63            SymExprKind::Primitive(sym_literal) => {
64                SymExprKind::Primitive(sym_literal.subst_with(db, bound_vars, subst_fns))
65            }
66            SymExprKind::ByteLiteral(sym_byte_literal) => {
67                SymExprKind::ByteLiteral(sym_byte_literal.subst_with(db, bound_vars, subst_fns))
68            }
69            SymExprKind::LetIn {
70                lv,
71                ty,
72                initializer,
73                body,
74            } => SymExprKind::LetIn {
75                lv: *lv,
76                ty: ty.subst_with(db, bound_vars, subst_fns),
77                initializer: initializer.subst_with(db, bound_vars, subst_fns),
78                body: bind_variable(*lv, bound_vars, |bound_vars| {
79                    body.subst_with(db, bound_vars, subst_fns)
80                }),
81            },
82            SymExprKind::Await {
83                future,
84                await_keyword,
85            } => SymExprKind::Await {
86                future: future.subst_with(db, bound_vars, subst_fns),
87                await_keyword: await_keyword.subst_with(db, bound_vars, subst_fns),
88            },
89            SymExprKind::Assign { place, value } => SymExprKind::Assign {
90                place: place.subst_with(db, bound_vars, subst_fns),
91                value: value.subst_with(db, bound_vars, subst_fns),
92            },
93            SymExprKind::PermissionOp(permission_op, sym_place_expr) => SymExprKind::PermissionOp(
94                permission_op.subst_with(db, bound_vars, subst_fns),
95                sym_place_expr.subst_with(db, bound_vars, subst_fns),
96            ),
97            SymExprKind::Call {
98                function,
99                substitution,
100                arg_temps,
101            } => SymExprKind::Call {
102                function: function.subst_with(db, bound_vars, subst_fns),
103                substitution: substitution.subst_with(db, bound_vars, subst_fns),
104                arg_temps: arg_temps
105                    .iter()
106                    .map(|&t| assert_bound_variable(db, t, bound_vars))
107                    .collect(),
108            },
109            SymExprKind::Return(sym_expr) => {
110                SymExprKind::Return(sym_expr.subst_with(db, bound_vars, subst_fns))
111            }
112            SymExprKind::Not { operand, op_span } => SymExprKind::Not {
113                operand: operand.subst_with(db, bound_vars, subst_fns),
114                op_span: op_span.subst_with(db, bound_vars, subst_fns),
115            },
116            SymExprKind::BinaryOp(sym_binary_op, sym_expr, sym_expr1) => SymExprKind::BinaryOp(
117                sym_binary_op.subst_with(db, bound_vars, subst_fns),
118                sym_expr.subst_with(db, bound_vars, subst_fns),
119                sym_expr1.subst_with(db, bound_vars, subst_fns),
120            ),
121            SymExprKind::Aggregate { ty, fields } => SymExprKind::Aggregate {
122                ty: ty.subst_with(db, bound_vars, subst_fns),
123                fields: fields.subst_with(db, bound_vars, subst_fns),
124            },
125            SymExprKind::Match { arms } => SymExprKind::Match {
126                arms: arms.subst_with(db, bound_vars, subst_fns),
127            },
128            SymExprKind::Error(reported) => {
129                SymExprKind::Error(reported.subst_with(db, bound_vars, subst_fns))
130            }
131        }
132    }
133}
134
135impl<'db> Subst<'db> for SymPlaceExpr<'db> {
136    type GenericTerm = SymGenericTerm<'db>;
137}
138
139impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPlaceExpr<'db> {
140    type Output = SymPlaceExpr<'db>;
141
142    fn identity(&self) -> Self::Output {
143        *self
144    }
145
146    fn subst_with<'subst>(
147        &'subst self,
148        db: &'db dyn crate::Db,
149        bound_vars: &mut Vec<SymVariable<'db>>,
150        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
151    ) -> Self::Output {
152        SymPlaceExpr::new(
153            db,
154            self.span(db).subst_with(db, bound_vars, subst_fns),
155            self.ty(db).subst_with(db, bound_vars, subst_fns),
156            self.kind(db).subst_with(db, bound_vars, subst_fns),
157        )
158    }
159}
160
161impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPlaceExprKind<'db> {
162    type Output = SymPlaceExprKind<'db>;
163
164    fn identity(&self) -> Self::Output {
165        self.clone()
166    }
167
168    fn subst_with<'subst>(
169        &'subst self,
170        db: &'db dyn crate::Db,
171        bound_vars: &mut Vec<SymVariable<'db>>,
172        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
173    ) -> Self::Output {
174        match *self {
175            SymPlaceExprKind::Var(sym_variable) => {
176                SymPlaceExprKind::Var(assert_bound_variable(db, sym_variable, bound_vars))
177            }
178            SymPlaceExprKind::Field(sym_place_expr, sym_field) => SymPlaceExprKind::Field(
179                sym_place_expr.subst_with(db, bound_vars, subst_fns),
180                sym_field.subst_with(db, bound_vars, subst_fns),
181            ),
182            SymPlaceExprKind::Error(reported) => {
183                SymPlaceExprKind::Error(reported.subst_with(db, bound_vars, subst_fns))
184            }
185        }
186    }
187}
188
189impl<'db> Subst<'db> for SymMatchArm<'db> {
190    type GenericTerm = SymGenericTerm<'db>;
191}
192
193impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymMatchArm<'db> {
194    type Output = SymMatchArm<'db>;
195
196    fn identity(&self) -> Self::Output {
197        self.clone()
198    }
199
200    fn subst_with<'subst>(
201        &'subst self,
202        db: &'db dyn crate::Db,
203        bound_vars: &mut Vec<SymVariable<'db>>,
204        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
205    ) -> Self::Output {
206        let SymMatchArm { condition, body } = self;
207        SymMatchArm {
208            condition: condition.subst_with(db, bound_vars, subst_fns),
209            body: body.subst_with(db, bound_vars, subst_fns),
210        }
211    }
212}
213
214impl<'db> Subst<'db> for SymFunctionSignature<'db> {
215    type GenericTerm = SymGenericTerm<'db>;
216}
217
218impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymFunctionSignature<'db> {
219    type Output = SymFunctionSignature<'db>;
220
221    fn identity(&self) -> Self::Output {
222        *self
223    }
224
225    fn subst_with<'subst>(
226        &'subst self,
227        db: &'db dyn crate::Db,
228        bound_vars: &mut Vec<SymVariable<'db>>,
229        subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
230    ) -> Self::Output {
231        let symbols = self.symbols(db);
232        let len = bound_vars.len();
233        bound_vars.extend_from_slice(&symbols.generic_variables);
234        bound_vars.extend_from_slice(&symbols.input_variables);
235        let input_output = self.input_output(db).subst_with(db, bound_vars, subst_fns);
236        bound_vars.truncate(len);
237        SymFunctionSignature::new(db, symbols.clone(), input_output)
238    }
239}
240
241identity_subst! {
242    for 'db {
243        SymBinaryOp,
244        PermissionOp,
245        SymLiteral,
246        SymByteLiteral<'db>,
247    }
248}
249
250fn bind_variable<'db, T>(
251    sym_variable: SymVariable<'db>,
252    bound_vars: &mut Vec<SymVariable<'db>>,
253    op: impl FnOnce(&mut Vec<SymVariable<'db>>) -> T,
254) -> T {
255    bound_vars.push(sym_variable);
256    let result = op(bound_vars);
257    bound_vars.pop().unwrap();
258    result
259}
260
261fn assert_bound_variable<'db>(
262    _db: &'db dyn crate::Db,
263    sym_variable: SymVariable<'db>,
264    bound_vars: &mut Vec<SymVariable<'db>>,
265) -> SymVariable<'db> {
266    // Program variables should always appear bound, never free, and hence are never substituted.
267    assert!(bound_vars.contains(&sym_variable));
268    sym_variable
269}