Skip to main content

dada_ir_sym/ir/
functions.rs

1use std::borrow::Cow;
2
3use dada_ir_ast::{
4    ast::{
5        AstAggregate, AstFunction, AstFunctionEffects, AstFunctionInput, AstMainFunction,
6        Identifier, SpannedIdentifier,
7    },
8    span::{SourceSpanned, Span, Spanned},
9};
10use dada_util::{FromImpls, SalsaSerialize};
11use salsa::Update;
12use serde::Serialize;
13
14use crate::{
15    check::{
16        scope::Scope,
17        scope_tree::{ScopeItem, ScopeTreeNode},
18    },
19    ir::{
20        binder::{Binder, LeafBoundTerm},
21        classes::SymAggregate,
22        populate::{PopulateDefaultSymbols, PopulateSignatureSymbols},
23        types::SymTy,
24        variables::SymVariable,
25    },
26};
27
28use super::{
29    classes::SymAggregateStyle,
30    generics::SymWhereClause,
31    types::{HasKind, SymGenericKind},
32};
33
34#[derive(SalsaSerialize)]
35#[salsa::tracked(debug)]
36pub struct SymFunction<'db> {
37    pub super_scope_item: ScopeItem<'db>,
38
39    #[tracked]
40    pub source: SymFunctionSource<'db>,
41}
42
43#[salsa::tracked]
44impl<'db> SymFunction<'db> {
45    #[salsa::tracked]
46    pub fn effects(self, db: &'db dyn crate::Db) -> SymFunctionEffects {
47        let source = self.source(db).effects(db);
48        SymFunctionEffects {
49            async_effect: source.async_effect.is_some(),
50        }
51    }
52}
53
54impl<'db> ScopeTreeNode<'db> for SymFunction<'db> {
55    fn into_scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
56        self.scope(db)
57    }
58
59    fn direct_super_scope(self, db: &'db dyn crate::Db) -> Option<ScopeItem<'db>> {
60        Some(self.super_scope_item(db))
61    }
62
63    fn direct_generic_parameters(self, db: &'db dyn crate::Db) -> &'db Vec<SymVariable<'db>> {
64        &self.symbols(db).generic_variables
65    }
66
67    fn push_direct_ast_where_clauses(
68        self,
69        db: &'db dyn crate::Db,
70        out: &mut Vec<dada_ir_ast::ast::AstWhereClause<'db>>,
71    ) {
72        let wc = match self.source(db) {
73            SymFunctionSource::Function(ast) => ast.where_clauses(db),
74            SymFunctionSource::Constructor(_, ast) => ast.where_clauses(db),
75            SymFunctionSource::MainFunction(_) => &None,
76        };
77
78        if let Some(wc) = wc {
79            out.extend(wc.clauses(db));
80        }
81    }
82}
83
84impl<'db> Spanned<'db> for SymFunction<'db> {
85    fn span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
86        self.source(db).name(db).span
87    }
88}
89
90impl<'db> SourceSpanned<'db> for SymFunction<'db> {
91    fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
92        self.source(db).source_span(db)
93    }
94}
95
96#[salsa::tracked]
97impl<'db> SymFunction<'db> {
98    /// Name of the function.
99    pub fn name(self, db: &'db dyn crate::Db) -> Identifier<'db> {
100        self.source(db).name(db).id
101    }
102
103    /// Span for the function name.
104    pub fn name_span(self, db: &'db dyn crate::Db) -> Span<'db> {
105        self.source(db).name(db).span
106    }
107
108    fn scope_from_symbols<'sym>(
109        self,
110        db: &'db dyn crate::Db,
111        symbols: &'sym SignatureSymbols<'db>,
112    ) -> Scope<'sym, 'db> {
113        self.super_scope_item(db)
114            .into_scope(db)
115            .with_link(Cow::Borrowed(&symbols.generic_variables[..]))
116            .with_link(Cow::Borrowed(&symbols.input_variables[..]))
117    }
118
119    #[salsa::tracked(return_ref)]
120    pub fn symbols(self, db: &'db dyn crate::Db) -> SignatureSymbols<'db> {
121        let source = self.source(db);
122
123        // Before we can populate the default symbols,
124        // we need to create a temporary scope with *just* the explicit symbols.
125        // This allows us to do name resolution on the names of types and things.
126        let mut just_explicit_symbols = SignatureSymbols::new(self);
127        source.populate_signature_symbols(db, &mut just_explicit_symbols);
128        let scope = self.scope_from_symbols(db, &just_explicit_symbols);
129
130        // Now add in any default symbols.
131        let mut with_default_symbols = just_explicit_symbols.clone();
132        source.populate_default_symbols(db, &scope, &mut with_default_symbols);
133
134        with_default_symbols
135    }
136
137    /// Returns the scope for this function; this has the function generics
138    /// and parameters in scope.
139    pub fn scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
140        let symbols = self.symbols(db);
141        self.scope_from_symbols(db, symbols)
142    }
143}
144
145#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Update, FromImpls, Serialize)]
146pub enum SymFunctionSource<'db> {
147    Function(AstFunction<'db>),
148
149    /// Generated `fn main()` from statements appearing at the top of a module
150    MainFunction(AstMainFunction<'db>),
151
152    /// Generated constructor from an aggregate like `struct Foo(x: u32)`
153    #[no_from_impl] // I'd prefer to be explicit
154    Constructor(SymAggregate<'db>, AstAggregate<'db>),
155}
156
157impl<'db> SymFunctionSource<'db> {
158    fn effects(self, db: &'db dyn crate::Db) -> AstFunctionEffects<'db> {
159        match self {
160            Self::Function(ast_function) => ast_function.effects(db),
161            Self::MainFunction(_) | Self::Constructor(..) => AstFunctionEffects::default(),
162        }
163    }
164
165    fn name(self, db: &'db dyn dada_ir_ast::Db) -> SpannedIdentifier<'db> {
166        match self {
167            Self::Function(ast_function) => ast_function.name(db),
168            Self::Constructor(class, _) => SpannedIdentifier {
169                span: class.name_span(db),
170                id: Identifier::new_ident(db),
171            },
172            Self::MainFunction(mfunc) => SpannedIdentifier {
173                span: mfunc.statements(db).span,
174                id: Identifier::main(db),
175            },
176        }
177    }
178
179    pub fn inputs(self, db: &'db dyn crate::Db) -> Cow<'db, [AstFunctionInput<'db>]> {
180        match self {
181            Self::Function(ast_function) => Cow::Borrowed(&ast_function.inputs(db).values),
182            Self::Constructor(_, class) => Cow::Owned(
183                class
184                    .inputs(db)
185                    .as_ref()
186                    .unwrap()
187                    .iter()
188                    .map(|i| i.variable(db).into())
189                    .collect::<Vec<_>>(),
190            ),
191            Self::MainFunction(_) => Cow::Borrowed(&[]),
192        }
193    }
194}
195
196impl<'db> SourceSpanned<'db> for SymFunctionSource<'db> {
197    fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
198        match self {
199            SymFunctionSource::Function(ast_function) => ast_function.span(db),
200            SymFunctionSource::Constructor(_, ast_aggregate) => ast_aggregate.span(db),
201            SymFunctionSource::MainFunction(mfunc) => mfunc.span(db),
202        }
203    }
204}
205
/// Effects a function may declare.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct SymFunctionEffects {
    /// Whether the function declares the async effect.
    pub async_effect: bool,
}
211
212#[derive(SalsaSerialize)]
213#[salsa::tracked(debug)]
214pub struct SymFunctionSignature<'db> {
215    #[return_ref]
216    pub symbols: SignatureSymbols<'db>,
217
218    /// Input/output types:
219    ///
220    /// * Outer binder is for generic symbols from the function and its surrounding scopes
221    /// * Inner binder is the function local variables.
222    #[return_ref]
223    pub input_output: Binder<'db, Binder<'db, SymInputOutput<'db>>>,
224}
225
226#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Update, Debug, Serialize)]
227pub struct SymInputOutput<'db> {
228    pub input_tys: Vec<SymTy<'db>>,
229    pub output_ty: SymTy<'db>,
230    pub where_clauses: Vec<SymWhereClause<'db>>,
231}
232
233impl<'db> LeafBoundTerm<'db> for SymInputOutput<'db> {}
234
235#[derive(Clone, Debug, PartialEq, Eq, Hash, Update, Serialize)]
236pub struct SignatureSymbols<'db> {
237    /// Source of these symbols
238    pub source: SignatureSource<'db>,
239
240    /// Generic parmaeters on the class or function (concatenated)
241    pub generic_variables: Vec<SymVariable<'db>>,
242
243    /// Symbols for the function input variables (if any)
244    pub input_variables: Vec<SymVariable<'db>>,
245}
246
247impl<'db> SignatureSymbols<'db> {
248    pub fn has_generics_of_kind(&self, db: &'db dyn crate::Db, kinds: &[SymGenericKind]) -> bool {
249        if self.generic_variables.len() != kinds.len() {
250            return false;
251        }
252        self.generic_variables
253            .iter()
254            .zip(kinds)
255            .all(|(&v, &k)| v.has_kind(db, k))
256    }
257}
258
259impl<'db> SignatureSymbols<'db> {
260    /// Create an empty set of signature symbols from a given source.
261    /// The actual symbols themselves are populated via the trait
262    /// [`PopulateSignatureSymbols`][].
263    pub fn new(source: impl Into<SignatureSource<'db>>) -> Self {
264        Self {
265            source: source.into(),
266            generic_variables: Vec::new(),
267            input_variables: Vec::new(),
268        }
269    }
270}
271
272#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Update, FromImpls, Serialize)]
273pub enum SignatureSource<'db> {
274    Class(SymAggregate<'db>),
275    Function(SymFunction<'db>),
276}
277
278impl<'db> SignatureSource<'db> {
279    pub fn aggr_style(self, db: &'db dyn crate::Db) -> Option<SymAggregateStyle> {
280        match self {
281            SignatureSource::Class(aggr) => Some(aggr.style(db)),
282            SignatureSource::Function(_) => None,
283        }
284    }
285}