1use std::fmt::Debug;
2
3use dada_ir_ast::{diagnostic::Reported, span::Span};
4use dada_util::{Map, Never};
5
6use crate::{
7 ir::binder::{Binder, BoundTerm, NeverBinder},
8 ir::functions::SymInputOutput,
9 ir::types::{
10 AssertKind, HasKind, SymGenericKind, SymGenericTerm, SymPerm, SymPermKind, SymPlace,
11 SymPlaceKind, SymTy, SymTyKind, SymTyName,
12 },
13 ir::variables::{FromVar, SymVariable},
14};
15
16use super::{
17 classes::SymField,
18 functions::SymFunction,
19 generics::{SymWhereClause, SymWhereClauseKind},
20 indices::InferVarIndex,
21};
22
23pub struct SubstitutionFns<'s, 'db, Term> {
24 pub free_var: &'s mut dyn FnMut(SymVariable<'db>) -> Option<Term>,
28
29 pub infer_var: &'s mut dyn FnMut(InferVarIndex) -> Option<Term>,
33}
34
35pub fn default_free_var<Term>(_: SymVariable<'_>) -> Option<Term> {
36 None
37}
38
39pub trait Subst<'db>: SubstWith<'db, Self::GenericTerm> + Debug {
41 type GenericTerm: Copy + HasKind<'db> + Debug + FromVar<'db>;
44
45 fn subst_vars(
49 &self,
50 db: &'db dyn crate::Db,
51 map: &Map<SymVariable<'db>, Self::GenericTerm>,
52 ) -> Self::Output {
53 debug_assert!(
54 map.iter()
55 .all(|(&var, term)| term.has_kind(db, var.kind(db)))
56 );
57
58 self.subst_with(
59 db,
60 &mut Default::default(),
61 &mut SubstitutionFns {
62 free_var: &mut |var| map.get(&var).copied(),
63 infer_var: &mut |_| None,
64 },
65 )
66 }
67
68 fn subst_var(
70 &self,
71 db: &'db dyn crate::Db,
72 var: SymVariable<'db>,
73 term: Self::GenericTerm,
74 ) -> Self::Output {
75 debug_assert!(term.has_kind(db, var.kind(db)));
76
77 self.subst_with(
78 db,
79 &mut Default::default(),
80 &mut SubstitutionFns {
81 free_var: &mut |v| if v == var { Some(term) } else { None },
82 infer_var: &mut |_| None,
83 },
84 )
85 }
86
87 fn resolve_infer_var(
90 &self,
91 db: &'db dyn crate::Db,
92 bound_vars: &mut Vec<SymVariable<'db>>,
93 mut op: impl FnMut(InferVarIndex) -> Option<Self::GenericTerm>,
94 ) -> Self::Output {
95 self.subst_with(
96 db,
97 bound_vars,
98 &mut SubstitutionFns {
99 free_var: &mut |_| None,
100 infer_var: &mut op,
101 },
102 )
103 }
104}
105
106pub trait SubstWith<'db, Term> {
112 type Output;
114
115 fn identity(&self) -> Self::Output;
117
118 fn subst_with<'subst>(
128 &'subst self,
129 db: &'db dyn crate::Db,
130 bound_vars: &mut Vec<SymVariable<'db>>,
131 subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
132 ) -> Self::Output;
133}
134
135impl<'db> Subst<'db> for Never {
136 type GenericTerm = SymGenericTerm<'db>;
137}
138
139impl<'db, Term> SubstWith<'db, Term> for Never {
140 type Output = Never;
141
142 fn identity(&self) -> Self::Output {
143 unreachable!()
144 }
145
146 fn subst_with<'subst>(
147 &'subst self,
148 _db: &'db dyn crate::Db,
149 _bound_vars: &mut Vec<SymVariable<'db>>,
150 _subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
151 ) -> Self::Output {
152 unreachable!()
153 }
154}
155
156impl<'db, T> Subst<'db> for &T
157where
158 T: Subst<'db>,
159{
160 type GenericTerm = T::GenericTerm;
161}
162
163impl<'db, T, Term> SubstWith<'db, Term> for &T
164where
165 T: SubstWith<'db, Term>,
166{
167 type Output = T::Output;
168
169 fn subst_with<'subst>(
170 &'subst self,
171 db: &'db dyn crate::Db,
172 bound_vars: &mut Vec<SymVariable<'db>>,
173 subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
174 ) -> Self::Output {
175 T::subst_with(self, db, bound_vars, subst_fns)
176 }
177
178 fn identity(&self) -> Self::Output {
179 T::identity(self)
180 }
181}
182
183impl<'db> Subst<'db> for SymGenericTerm<'db> {
184 type GenericTerm = SymGenericTerm<'db>;
185}
186
187impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymGenericTerm<'db> {
188 type Output = Self;
189
190 fn subst_with<'subst>(
191 &'subst self,
192 db: &'db dyn crate::Db,
193 bound_vars: &mut Vec<SymVariable<'db>>,
194 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
195 ) -> Self::Output {
196 match self {
197 SymGenericTerm::Type(ty) => {
198 SymGenericTerm::Type(ty.subst_with(db, bound_vars, subst_fns))
199 }
200 SymGenericTerm::Perm(perm) => {
201 SymGenericTerm::Perm(perm.subst_with(db, bound_vars, subst_fns))
202 }
203 SymGenericTerm::Place(place) => {
204 SymGenericTerm::Place(place.subst_with(db, bound_vars, subst_fns))
205 }
206 SymGenericTerm::Error(e) => {
207 SymGenericTerm::Error(e.subst_with(db, bound_vars, subst_fns))
208 }
209 }
210 }
211
212 fn identity(&self) -> Self::Output {
213 *self
214 }
215}
216
217impl<'db> Subst<'db> for SymTy<'db> {
218 type GenericTerm = SymGenericTerm<'db>;
219}
220
221impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymTy<'db> {
222 type Output = Self;
223
224 fn subst_with<'subst>(
225 &'subst self,
226 db: &'db dyn crate::Db,
227 bound_vars: &mut Vec<SymVariable<'db>>,
228 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
229 ) -> Self::Output {
230 match self.kind(db) {
231 SymTyKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
233 SymTyKind::Infer(v) => {
234 if let Some(term) = (subst_fns.infer_var)(*v) {
235 term.assert_type(db)
236 } else {
237 self.identity()
238 }
239 }
240
241 SymTyKind::Perm(sym_perm, sym_ty) => SymTy::new(
243 db,
244 SymTyKind::Perm(
245 sym_perm.subst_with(db, bound_vars, subst_fns),
246 sym_ty.subst_with(db, bound_vars, subst_fns),
247 ),
248 ),
249 SymTyKind::Named(sym_ty_name, vec) => SymTy::new(
250 db,
251 SymTyKind::Named(
252 sym_ty_name.subst_with(db, bound_vars, subst_fns),
253 vec.iter()
254 .map(|g| g.subst_with(db, bound_vars, subst_fns))
255 .collect(),
256 ),
257 ),
258 SymTyKind::Never => self.identity(),
259 SymTyKind::Error(_) => self.identity(),
260 }
261 }
262
263 fn identity(&self) -> Self::Output {
264 *self
265 }
266}
267
268impl<'db> Subst<'db> for SymPerm<'db> {
269 type GenericTerm = SymGenericTerm<'db>;
270}
271
272impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPerm<'db> {
273 type Output = Self;
274
275 fn identity(&self) -> Self::Output {
276 *self
277 }
278
279 fn subst_with<'subst>(
280 &self,
281 db: &'db dyn crate::Db,
282 bound_vars: &mut Vec<SymVariable<'db>>,
283 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
284 ) -> Self::Output {
285 match self.kind(db) {
286 SymPermKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
287 SymPermKind::Infer(v) => {
288 if let Some(term) = (subst_fns.infer_var)(*v) {
289 term.assert_perm(db)
290 } else {
291 self.identity()
292 }
293 }
294 SymPermKind::Referenced(vec) => SymPerm::new(
295 db,
296 SymPermKind::Referenced(
297 vec.iter()
298 .map(|g| g.subst_with(db, bound_vars, subst_fns))
299 .collect(),
300 ),
301 ),
302 SymPermKind::Mutable(vec) => SymPerm::new(
303 db,
304 SymPermKind::Mutable(
305 vec.iter()
306 .map(|g| g.subst_with(db, bound_vars, subst_fns))
307 .collect(),
308 ),
309 ),
310 SymPermKind::Error(reported) => SymPerm::new(
311 db,
312 SymPermKind::Error(reported.subst_with(db, bound_vars, subst_fns)),
313 ),
314 SymPermKind::My => self.identity(),
315 SymPermKind::Our => self.identity(),
316 SymPermKind::Apply(left, right) => SymPerm::new(
317 db,
318 SymPermKind::Apply(
319 left.subst_with(db, bound_vars, subst_fns),
320 right.subst_with(db, bound_vars, subst_fns),
321 ),
322 ),
323 SymPermKind::Or(left, right) => SymPerm::new(
324 db,
325 SymPermKind::Or(
326 left.subst_with(db, bound_vars, subst_fns),
327 right.subst_with(db, bound_vars, subst_fns),
328 ),
329 ),
330 }
331 }
332}
333
334impl<'db> Subst<'db> for SymPlace<'db> {
335 type GenericTerm = SymGenericTerm<'db>;
336}
337
338impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymPlace<'db> {
339 type Output = Self;
340
341 fn identity(&self) -> Self::Output {
342 *self
343 }
344
345 fn subst_with<'subst>(
346 &'subst self,
347 db: &'db dyn crate::Db,
348 bound_vars: &mut Vec<SymVariable<'db>>,
349 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
350 ) -> Self::Output {
351 match self.kind(db) {
352 SymPlaceKind::Var(var) => subst_var(db, bound_vars, subst_fns, *var),
353 SymPlaceKind::Field(sym_place, identifier) => SymPlace::new(
354 db,
355 SymPlaceKind::Field(sym_place.subst_with(db, bound_vars, subst_fns), *identifier),
356 ),
357 SymPlaceKind::Index(sym_place) => SymPlace::new(
358 db,
359 SymPlaceKind::Index(sym_place.subst_with(db, bound_vars, subst_fns)),
360 ),
361 SymPlaceKind::Error(reported) => SymPlace::new(
362 db,
363 SymPlaceKind::Error(reported.subst_with(db, bound_vars, subst_fns)),
364 ),
365 SymPlaceKind::Erased => SymPlace::new(db, SymPlaceKind::Erased),
366 }
367 }
368}
369
370impl<'db> Subst<'db> for SymWhereClause<'db> {
371 type GenericTerm = SymGenericTerm<'db>;
372}
373
374impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymWhereClause<'db> {
375 type Output = SymWhereClause<'db>;
376
377 fn identity(&self) -> Self::Output {
378 *self
379 }
380
381 fn subst_with<'subst>(
382 &'subst self,
383 db: &'db dyn crate::Db,
384 bound_vars: &mut Vec<SymVariable<'db>>,
385 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
386 ) -> Self::Output {
387 SymWhereClause::new(
388 db,
389 self.subject(db).subst_with(db, bound_vars, subst_fns),
390 self.kind(db).subst_with(db, bound_vars, subst_fns),
391 )
392 }
393}
394
395impl<'db, T: BoundTerm<'db>> Subst<'db> for Binder<'db, T>
396where
397 T::Output: BoundTerm<'db>,
398{
399 type GenericTerm = T::GenericTerm;
400}
401
402impl<'db, T: BoundTerm<'db>> SubstWith<'db, T::GenericTerm> for Binder<'db, T> {
403 type Output = Binder<'db, T>;
404
405 fn identity(&self) -> Self::Output {
406 Binder {
407 variables: self.variables.clone(),
408 bound_value: T::identity(&self.bound_value),
409 }
410 }
411
412 fn subst_with<'subst>(
413 &'subst self,
414 db: &'db dyn crate::Db,
415 bound_vars: &mut Vec<SymVariable<'db>>,
416 subst_fns: &mut SubstitutionFns<'_, 'db, T::GenericTerm>,
417 ) -> Self::Output {
418 let len = bound_vars.len();
419 bound_vars.extend_from_slice(&self.variables);
420 let bound_value = self.bound_value.subst_with(db, bound_vars, subst_fns);
421 bound_vars.truncate(len);
422
423 Binder {
424 variables: self.variables.clone(),
425 bound_value,
426 }
427 }
428}
429
430impl<'db, T> Subst<'db> for NeverBinder<T>
431where
432 T: Debug,
433{
434 type GenericTerm = SymGenericTerm<'db>;
435}
436
437impl<'db, T, Term> SubstWith<'db, Term> for NeverBinder<T> {
438 type Output = Self;
439
440 fn identity(&self) -> Self::Output {
441 unreachable!()
442 }
443
444 fn subst_with<'subst>(
445 &'subst self,
446 _db: &'db dyn crate::Db,
447 _bound_vars: &mut Vec<SymVariable<'db>>,
448 _subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
449 ) -> Self::Output {
450 unreachable!()
451 }
452}
453
454impl<'db> Subst<'db> for SymInputOutput<'db> {
455 type GenericTerm = SymGenericTerm<'db>;
456}
457
458impl<'db> SubstWith<'db, SymGenericTerm<'db>> for SymInputOutput<'db> {
459 type Output = Self;
460
461 fn identity(&self) -> Self::Output {
462 self.clone()
463 }
464
465 fn subst_with<'subst>(
466 &'subst self,
467 db: &'db dyn crate::Db,
468 bound_vars: &mut Vec<SymVariable<'db>>,
469 subst_fns: &mut SubstitutionFns<'_, 'db, SymGenericTerm<'db>>,
470 ) -> Self::Output {
471 SymInputOutput {
472 input_tys: self.input_tys.subst_with(db, bound_vars, subst_fns),
473 output_ty: self.output_ty.subst_with(db, bound_vars, subst_fns),
474 where_clauses: self.where_clauses.subst_with(db, bound_vars, subst_fns),
475 }
476 }
477}
478
479impl<'db, T> Subst<'db> for Vec<T>
480where
481 T: Subst<'db>,
482{
483 type GenericTerm = T::GenericTerm;
484}
485
486impl<'db, T: Subst<'db>> SubstWith<'db, T::GenericTerm> for Vec<T> {
487 type Output = Vec<T::Output>;
488
489 fn identity(&self) -> Self::Output {
490 self.iter().map(T::identity).collect()
491 }
492
493 fn subst_with<'subst>(
494 &'subst self,
495 db: &'db dyn crate::Db,
496 bound_vars: &mut Vec<SymVariable<'db>>,
497 subst_fns: &mut SubstitutionFns<'_, 'db, T::GenericTerm>,
498 ) -> Self::Output {
499 self.iter()
500 .map(|t| t.subst_with(db, bound_vars, subst_fns))
501 .collect()
502 }
503}
504
505pub fn subst_var<'db, Output, Term>(
506 db: &'db dyn crate::Db,
507 bound_vars: &mut Vec<SymVariable<'db>>,
508 subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
509 var: SymVariable<'db>,
510) -> Output
511where
512 Term: AssertKind<'db, Output>,
513 Output: FromVar<'db>,
514{
515 let var_appears_free = !bound_vars.contains(&var);
516
517 if var_appears_free && let Some(term) = (subst_fns.free_var)(var) {
518 return term.assert_kind(db);
519 }
520
521 Output::var(db, var)
522}
523
524macro_rules! identity_subst {
530 (for $l:lifetime { $($t:ty,)* }) => {
531 $(
532 impl<$l> Subst<$l> for $t {
533 type GenericTerm = SymGenericTerm<$l>;
534 }
535
536 impl<$l, Term> SubstWith<$l, Term> for $t {
537 type Output = Self;
538
539 fn identity(&self) -> Self::Output {
540 *self
541 }
542
543 fn subst_with<'subst>(
544 &self,
545 _db: &$l dyn crate::Db,
546 _bound_vars: &mut Vec<SymVariable<'db>>,
547 _subst_fns: &mut SubstitutionFns<'_, $l, Term>,
548 ) -> Self::Output {
549 *self
550 }
551 }
552 )*
553 };
554}
555pub(crate) use identity_subst; identity_subst! {
558 for 'db {
559 (),
560 Reported,
561 SymGenericKind,
562 SymTyName<'db>,
563 Span<'db>,
564 SymFunction<'db>,
565 SymField<'db>,
566 SymWhereClauseKind,
567 }
568}
569
570impl<'db, Term, T> SubstWith<'db, Term> for Option<T>
571where
572 T: SubstWith<'db, Term>,
573{
574 type Output = Option<T::Output>;
575
576 fn identity(&self) -> Self::Output {
577 self.as_ref().map(|v| v.identity())
578 }
579
580 fn subst_with<'subst>(
581 &'subst self,
582 db: &'db dyn crate::Db,
583 bound_vars: &mut Vec<SymVariable<'db>>,
584 subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
585 ) -> Self::Output {
586 self.as_ref()
587 .map(|v| v.subst_with(db, bound_vars, subst_fns))
588 }
589}
590
591impl<'db, O, E, Term> SubstWith<'db, Term> for Result<O, E>
592where
593 O: SubstWith<'db, Term>,
594 E: SubstWith<'db, Term>,
595{
596 type Output = Result<O::Output, E::Output>;
597
598 fn identity(&self) -> Self::Output {
599 match self {
600 Ok(v) => Ok(v.identity()),
601 Err(e) => Err(e.identity()),
602 }
603 }
604
605 fn subst_with<'subst>(
606 &'subst self,
607 db: &'db dyn crate::Db,
608 bound_vars: &mut Vec<SymVariable<'db>>,
609 subst_fns: &mut SubstitutionFns<'_, 'db, Term>,
610 ) -> Self::Output {
611 match self {
612 Ok(v) => Ok(v.subst_with(db, bound_vars, subst_fns)),
613 Err(e) => Err(e.subst_with(db, bound_vars, subst_fns)),
614 }
615 }
616}