dada_ir_sym/check/
resolve.rs1use std::collections::hash_map::Entry;
4
5use dada_ir_ast::diagnostic::{Diagnostic, Err, Reported};
6use dada_util::Map;
7
8use crate::ir::{
9 classes::SymAggregateStyle,
10 indices::InferVarIndex,
11 subst::Subst,
12 types::{SymGenericTerm, SymPerm, SymTy, SymTyKind},
13};
14
15use super::{
16 Env,
17 inference::{Direction, InferVarKind},
18 red::RedTy,
19};
20
21pub struct Resolver<'env, 'db> {
22 db: &'db dyn crate::Db,
23 env: &'env mut Env<'db>,
24
25 memoized_ty: Map<InferVarIndex, Result<SymTy<'db>, ResolverCycle>>,
36
37 memoized_perm: Map<InferVarIndex, SymPerm<'db>>,
46}
47
48impl<'env, 'db> Resolver<'env, 'db> {
49 pub fn new(env: &'env mut Env<'db>) -> Self {
50 assert!(
51 env.runtime().check_complete(),
52 "resolution is only possible once type constraints are known"
53 );
54
55 Self {
56 db: env.db(),
57 env,
58 memoized_ty: Default::default(),
59 memoized_perm: Default::default(),
60 }
61 }
62
63 pub fn resolve<T>(&mut self, term: T) -> T::Output
65 where
66 T: Subst<'db, GenericTerm = SymGenericTerm<'db>>,
67 {
68 let mut bound_vars = self.env.bound_vars();
69 term.resolve_infer_var(self.db, &mut bound_vars, |infer| {
70 match self.resolve_infer_var(infer) {
71 Ok(v) => Some(v),
72 Result::Err(error) => Some(SymGenericTerm::err(self.db, self.report(infer, error))),
73 }
74 })
75 }
76
77 fn resolve_infer_var(
79 &mut self,
80 infer: InferVarIndex,
81 ) -> Result<SymGenericTerm<'db>, ResolverCycle> {
82 match self.env.infer_var_kind(infer) {
83 InferVarKind::Type => Ok(self.resolve_ty_var(infer)?.into()),
84 InferVarKind::Perm => Ok(self.resolve_perm_var(infer).into()),
85 }
86 }
87
88 fn resolve_ty_var(&mut self, infer: InferVarIndex) -> Result<SymTy<'db>, ResolverCycle> {
89 match self.memoized_ty.entry(infer) {
90 Entry::Occupied(entry) => {
91 return *entry.get();
92 }
93 Entry::Vacant(entry) => {
94 entry.insert(Err(ResolverCycle));
95 }
96 }
97
98 let ty = if let Some(t) = self.bounding_ty(infer, Direction::FromBelow)? {
99 t
100 } else if let Some(t) = self.bounding_ty(infer, Direction::FromAbove)? {
101 t
102 } else {
103 panic!("found no inference bounds, odd")
105 };
106
107 self.memoized_ty.insert(infer, Ok(ty));
108 Ok(ty)
109 }
110
111 fn bounding_ty(
113 &mut self,
114 infer: InferVarIndex,
115 direction: Direction,
116 ) -> Result<Option<SymTy<'db>>, ResolverCycle> {
117 let db = self.env.db();
118
119 let bound = self.env.red_bound(infer, direction).peek_ty();
120
121 let Some((red_ty, _)) = bound else {
122 return Ok(None);
123 };
124
125 let apply_perm = |this: &mut Self, sym_ty: SymTy<'db>| {
126 let perm_infer = this.env.perm_infer(infer);
127 let sym_perm = this.resolve_perm_var(perm_infer);
128 sym_perm.apply_to(db, sym_ty)
129 };
130
131 Ok(Some(match red_ty {
132 RedTy::Error(reported) => SymTy::err(db, reported),
133 RedTy::Named(name, args) => {
134 let args = self.resolve(args);
135 let ty = SymTy::new(db, SymTyKind::Named(name, args));
136 match name.style(db) {
137 SymAggregateStyle::Struct => ty,
138 SymAggregateStyle::Class => apply_perm(self, ty),
139 }
140 }
141 RedTy::Never => SymTy::new(db, SymTyKind::Never),
142 RedTy::Infer(_) => panic!("infer bound cannot be another infer"),
143 RedTy::Var(var) => apply_perm(self, SymTy::new(db, SymTyKind::Var(var))),
144 RedTy::Perm => panic!("infer bound cannot be a perm"),
145 }))
146 }
147
148 fn resolve_perm_var(&mut self, infer: InferVarIndex) -> SymPerm<'db> {
149 if let Some(perm) = self.memoized_perm.get(&infer) {
150 return *perm;
151 }
152
153 let perm = if let Some(t) = self.bounding_perm(infer, Direction::FromBelow) {
154 t
155 } else if let Some(t) = self.bounding_perm(infer, Direction::FromAbove) {
156 t
157 } else {
158 panic!("found no inference bounds, odd")
160 };
161
162 self.memoized_perm.insert(infer, perm);
163 perm
164 }
165
166 fn bounding_perm(&self, infer: InferVarIndex, direction: Direction) -> Option<SymPerm<'db>> {
168 let runtime = self.env.runtime().clone();
169 runtime.with_inference_var_data(infer, |data| {
170 data.red_perm_bound(direction)
171 .map(|(bound, _)| bound.to_sym_perm(self.db))
172 })
173 }
174
175 fn report(&self, infer: InferVarIndex, _err: ResolverCycle) -> Reported {
176 let span = self.env.infer_var_span(infer);
177 Diagnostic::error(self.db, span, "cyclic bounds found for inference variable")
178 .report(self.db)
179 }
180}
181
/// Marker error meaning that an inference variable's resolution depended on itself.
///
/// `Resolver::resolve_ty_var` stores `Err(ResolverCycle)` in its memo table while a
/// type variable is being resolved. A re-entrant lookup of the same variable
/// therefore observes this value.
///
/// `Debug` is derived so that `Result<_, ResolverCycle>` works with `unwrap`,
/// `expect` and `assert_eq!` while debugging.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct ResolverCycle;