Skip to main content

dada_ir_sym/check/
signature.rs

1use dada_ir_ast::{
2    ast::AstFunctionInput,
3    diagnostic::{Diagnostic, Err, Errors},
4    span::Spanned,
5};
6
7use crate::{
8    check::{env::Env, runtime::Runtime},
9    ir::{
10        functions::{SymFunction, SymFunctionSignature, SymFunctionSource, SymInputOutput},
11        generics::SymWhereClause,
12        populate::self_arg_requires_default_perm,
13        types::{AnonymousPermSymbol, SymPerm, SymTy, SymTyName},
14        variables::SymVariable,
15    },
16    prelude::Symbol,
17};
18
19use super::{CheckTyInEnv, generics::symbolify_ast_where_clause, scope_tree::ScopeTreeNode};
20
21pub fn check_function_signature<'db>(
22    db: &'db dyn crate::Db,
23    function: SymFunction<'db>,
24) -> Errors<SymFunctionSignature<'db>> {
25    Runtime::execute(
26        db,
27        function.name_span(db),
28        "check_function_signature",
29        &[&function],
30        async move |runtime| -> Errors<SymFunctionSignature<'db>> {
31            let PreparedEnv {
32                env,
33                input_tys,
34                output_ty_caller,
35                where_clauses,
36                ..
37            } = prepare_env(db, runtime, function).await;
38
39            let scope = env.into_scope();
40            Ok(SymFunctionSignature::new(
41                db,
42                function.symbols(db).clone(),
43                scope.into_bound_value(
44                    db,
45                    SymInputOutput {
46                        input_tys,
47                        output_ty: output_ty_caller,
48                        where_clauses,
49                    },
50                ),
51            ))
52        },
53        |v| v,
54    )
55}
56
57#[derive(Debug)]
58pub struct PreparedEnv<'db> {
59    /// The env that should be used to type check the body
60    pub env: Env<'db>,
61
62    /// The generic variables declared on the fn
63    pub input_symbols: Vec<SymVariable<'db>>,
64
65    /// The types of the fn inputs
66    pub input_tys: Vec<SymTy<'db>>,
67
68    /// The return type the block should generate.
69    /// This is the type that the user wrote.
70    /// In the case of an async fn, this is not a future.
71    pub output_ty_body: SymTy<'db>,
72
73    /// The return type of the fn from the perspective of the caller.
74    /// For an async fn, this is a future.
75    pub output_ty_caller: SymTy<'db>,
76
77    /// Where clauses in scope
78    pub where_clauses: Vec<SymWhereClause<'db>>,
79}
80
81pub async fn prepare_env<'db>(
82    db: &'db dyn crate::Db,
83    runtime: &Runtime<'db>,
84    function: SymFunction<'db>,
85) -> PreparedEnv<'db> {
86    let source = function.source(db);
87    let inputs = source.inputs(db);
88    let input_symbols = inputs
89        .iter()
90        .map(|input| input.symbol(db))
91        .collect::<Vec<_>>();
92
93    let mut env: Env<'db> = Env::new(runtime, function.scope(db));
94
95    // Set the AST types for the inputs.
96    for i in source.inputs(db).iter() {
97        set_variable_ty_from_input(&mut env, i).await;
98    }
99
100    // Now that all input types are available, symbolify and create `input_tys` vector.
101    let mut input_tys: Vec<SymTy<'db>> = vec![];
102    for i in source.inputs(db).iter() {
103        let ty = env.variable_ty(i.symbol(db)).await;
104        input_tys.push(ty);
105    }
106
107    // Symbolify the output type.
108    let output_ty_body: SymTy<'db> = output_ty(&mut env, &function).await;
109    env.set_return_ty(output_ty_body);
110
111    let output_ty_caller = if function.effects(db).async_effect {
112        SymTy::named(db, SymTyName::Future, vec![output_ty_body.into()])
113    } else {
114        output_ty_body
115    };
116
117    // Symbolify the where-clauses
118    let mut ast_where_clauses = vec![];
119    let mut where_clauses = vec![];
120    function.push_transitive_where_clauses(db, &mut ast_where_clauses);
121    for ast_where_clause in ast_where_clauses {
122        symbolify_ast_where_clause(&mut env, ast_where_clause, &mut where_clauses).await;
123    }
124
125    PreparedEnv {
126        env,
127        input_symbols,
128        input_tys,
129        output_ty_body,
130        output_ty_caller,
131        where_clauses,
132    }
133}
134
135async fn set_variable_ty_from_input<'db>(env: &mut Env<'db>, input: &AstFunctionInput<'db>) {
136    let db = env.db();
137    let lv = input.symbol(db);
138    match input {
139        AstFunctionInput::SelfArg(arg) => {
140            let self_ty = if let Some(aggregate) = env.scope.aggregate() {
141                let aggr_ty = aggregate.self_ty(db, &env.scope);
142                if let Some(ast_perm) = arg.perm(db) {
143                    let sym_perm = ast_perm.check_in_env(env).await;
144                    SymTy::perm(db, sym_perm, aggr_ty)
145                } else if self_arg_requires_default_perm(db, *arg, &env.scope) {
146                    let sym_perm = SymPerm::var(db, arg.anonymous_perm_symbol(db));
147                    SymTy::perm(db, sym_perm, aggr_ty)
148                } else {
149                    aggr_ty
150                }
151            } else {
152                SymTy::err(
153                    db,
154                    Diagnostic::error(
155                        db,
156                        arg.span(db),
157                        "self parameter is only permitted within a class definition",
158                    )
159                    .report(db),
160                )
161            };
162            env.set_variable_sym_ty(lv, self_ty);
163        }
164        AstFunctionInput::Variable(decl) => env.set_variable_ast_ty(lv, *decl),
165    }
166}
167
168async fn output_ty<'db>(env: &mut Env<'db>, function: &SymFunction<'db>) -> SymTy<'db> {
169    let db = env.db();
170    match function.source(db) {
171        SymFunctionSource::Function(ast_function) => match ast_function.output_ty(db) {
172            Some(ast_ty) => ast_ty.check_in_env(env).await,
173            None => SymTy::unit(db),
174        },
175        SymFunctionSource::MainFunction(_) => SymTy::unit(env.db()),
176        SymFunctionSource::Constructor(sym_aggregate, _ast_aggregate) => {
177            sym_aggregate.self_ty(db, &env.scope)
178        }
179    }
180}