1use dada_ir_ast::{
2 ast::AstFunctionInput,
3 diagnostic::{Diagnostic, Err, Errors},
4 span::Spanned,
5};
6
7use crate::{
8 check::{env::Env, runtime::Runtime},
9 ir::{
10 functions::{SymFunction, SymFunctionSignature, SymFunctionSource, SymInputOutput},
11 generics::SymWhereClause,
12 populate::self_arg_requires_default_perm,
13 types::{AnonymousPermSymbol, SymPerm, SymTy, SymTyName},
14 variables::SymVariable,
15 },
16 prelude::Symbol,
17};
18
19use super::{CheckTyInEnv, generics::symbolify_ast_where_clause, scope_tree::ScopeTreeNode};
20
21pub fn check_function_signature<'db>(
22 db: &'db dyn crate::Db,
23 function: SymFunction<'db>,
24) -> Errors<SymFunctionSignature<'db>> {
25 Runtime::execute(
26 db,
27 function.name_span(db),
28 "check_function_signature",
29 &[&function],
30 async move |runtime| -> Errors<SymFunctionSignature<'db>> {
31 let PreparedEnv {
32 env,
33 input_tys,
34 output_ty_caller,
35 where_clauses,
36 ..
37 } = prepare_env(db, runtime, function).await;
38
39 let scope = env.into_scope();
40 Ok(SymFunctionSignature::new(
41 db,
42 function.symbols(db).clone(),
43 scope.into_bound_value(
44 db,
45 SymInputOutput {
46 input_tys,
47 output_ty: output_ty_caller,
48 where_clauses,
49 },
50 ),
51 ))
52 },
53 |v| v,
54 )
55}
56
/// The environment and signature types that [`prepare_env`] builds for a
/// single function.
#[derive(Debug)]
pub struct PreparedEnv<'db> {
    /// Environment for the function.
    ///
    /// By the time it is returned, every input has a type, the body's return
    /// type is set, and the where-clauses have been symbolified in it.
    pub env: Env<'db>,

    /// Symbol of each input, in the order the function source lists them.
    pub input_symbols: Vec<SymVariable<'db>>,

    /// Type of each input, in the same order as `input_symbols`.
    pub input_tys: Vec<SymTy<'db>>,

    /// Output type as seen from inside the function body, before any
    /// `Future` wrapping.
    ///
    /// This is the type installed as the environment's return type.
    pub output_ty_body: SymTy<'db>,

    /// Output type as seen by callers.
    ///
    /// For an async function this is `Future` applied to `output_ty_body`;
    /// otherwise it is identical to `output_ty_body`.
    pub output_ty_caller: SymTy<'db>,

    /// The function's where-clauses, including transitive ones, symbolified
    /// in `env`.
    pub where_clauses: Vec<SymWhereClause<'db>>,
}
80
81pub async fn prepare_env<'db>(
82 db: &'db dyn crate::Db,
83 runtime: &Runtime<'db>,
84 function: SymFunction<'db>,
85) -> PreparedEnv<'db> {
86 let source = function.source(db);
87 let inputs = source.inputs(db);
88 let input_symbols = inputs
89 .iter()
90 .map(|input| input.symbol(db))
91 .collect::<Vec<_>>();
92
93 let mut env: Env<'db> = Env::new(runtime, function.scope(db));
94
95 for i in source.inputs(db).iter() {
97 set_variable_ty_from_input(&mut env, i).await;
98 }
99
100 let mut input_tys: Vec<SymTy<'db>> = vec![];
102 for i in source.inputs(db).iter() {
103 let ty = env.variable_ty(i.symbol(db)).await;
104 input_tys.push(ty);
105 }
106
107 let output_ty_body: SymTy<'db> = output_ty(&mut env, &function).await;
109 env.set_return_ty(output_ty_body);
110
111 let output_ty_caller = if function.effects(db).async_effect {
112 SymTy::named(db, SymTyName::Future, vec![output_ty_body.into()])
113 } else {
114 output_ty_body
115 };
116
117 let mut ast_where_clauses = vec![];
119 let mut where_clauses = vec![];
120 function.push_transitive_where_clauses(db, &mut ast_where_clauses);
121 for ast_where_clause in ast_where_clauses {
122 symbolify_ast_where_clause(&mut env, ast_where_clause, &mut where_clauses).await;
123 }
124
125 PreparedEnv {
126 env,
127 input_symbols,
128 input_tys,
129 output_ty_body,
130 output_ty_caller,
131 where_clauses,
132 }
133}
134
135async fn set_variable_ty_from_input<'db>(env: &mut Env<'db>, input: &AstFunctionInput<'db>) {
136 let db = env.db();
137 let lv = input.symbol(db);
138 match input {
139 AstFunctionInput::SelfArg(arg) => {
140 let self_ty = if let Some(aggregate) = env.scope.aggregate() {
141 let aggr_ty = aggregate.self_ty(db, &env.scope);
142 if let Some(ast_perm) = arg.perm(db) {
143 let sym_perm = ast_perm.check_in_env(env).await;
144 SymTy::perm(db, sym_perm, aggr_ty)
145 } else if self_arg_requires_default_perm(db, *arg, &env.scope) {
146 let sym_perm = SymPerm::var(db, arg.anonymous_perm_symbol(db));
147 SymTy::perm(db, sym_perm, aggr_ty)
148 } else {
149 aggr_ty
150 }
151 } else {
152 SymTy::err(
153 db,
154 Diagnostic::error(
155 db,
156 arg.span(db),
157 "self parameter is only permitted within a class definition",
158 )
159 .report(db),
160 )
161 };
162 env.set_variable_sym_ty(lv, self_ty);
163 }
164 AstFunctionInput::Variable(decl) => env.set_variable_ast_ty(lv, *decl),
165 }
166}
167
168async fn output_ty<'db>(env: &mut Env<'db>, function: &SymFunction<'db>) -> SymTy<'db> {
169 let db = env.db();
170 match function.source(db) {
171 SymFunctionSource::Function(ast_function) => match ast_function.output_ty(db) {
172 Some(ast_ty) => ast_ty.check_in_env(env).await,
173 None => SymTy::unit(db),
174 },
175 SymFunctionSource::MainFunction(_) => SymTy::unit(env.db()),
176 SymFunctionSource::Constructor(sym_aggregate, _ast_aggregate) => {
177 sym_aggregate.self_ty(db, &env.scope)
178 }
179 }
180}