1use std::borrow::Cow;
2
3use dada_ir_ast::{
4 ast::{
5 AstAggregate, AstFunction, AstFunctionEffects, AstFunctionInput, AstMainFunction,
6 Identifier, SpannedIdentifier,
7 },
8 span::{SourceSpanned, Span, Spanned},
9};
10use dada_util::{FromImpls, SalsaSerialize};
11use salsa::Update;
12use serde::Serialize;
13
14use crate::{
15 check::{
16 scope::Scope,
17 scope_tree::{ScopeItem, ScopeTreeNode},
18 },
19 ir::{
20 binder::{Binder, LeafBoundTerm},
21 classes::SymAggregate,
22 populate::{PopulateDefaultSymbols, PopulateSignatureSymbols},
23 types::SymTy,
24 variables::SymVariable,
25 },
26};
27
28use super::{
29 classes::SymAggregateStyle,
30 generics::SymWhereClause,
31 types::{HasKind, SymGenericKind},
32};
33
/// A function-like symbol.
///
/// It may come from an ordinary function declaration, the `main` function,
/// or an aggregate's constructor; see [`SymFunctionSource`].
#[derive(SalsaSerialize)]
#[salsa::tracked(debug)]
pub struct SymFunction<'db> {
    /// The scope item this function is nested in.
    ///
    /// It is reported as the function's direct super scope, and it is the
    /// starting point when building the function's own scope.
    pub super_scope_item: ScopeItem<'db>,

    /// The syntax this function was created from.
    #[tracked]
    pub source: SymFunctionSource<'db>,
}
42
43#[salsa::tracked]
44impl<'db> SymFunction<'db> {
45 #[salsa::tracked]
46 pub fn effects(self, db: &'db dyn crate::Db) -> SymFunctionEffects {
47 let source = self.source(db).effects(db);
48 SymFunctionEffects {
49 async_effect: source.async_effect.is_some(),
50 }
51 }
52}
53
54impl<'db> ScopeTreeNode<'db> for SymFunction<'db> {
55 fn into_scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
56 self.scope(db)
57 }
58
59 fn direct_super_scope(self, db: &'db dyn crate::Db) -> Option<ScopeItem<'db>> {
60 Some(self.super_scope_item(db))
61 }
62
63 fn direct_generic_parameters(self, db: &'db dyn crate::Db) -> &'db Vec<SymVariable<'db>> {
64 &self.symbols(db).generic_variables
65 }
66
67 fn push_direct_ast_where_clauses(
68 self,
69 db: &'db dyn crate::Db,
70 out: &mut Vec<dada_ir_ast::ast::AstWhereClause<'db>>,
71 ) {
72 let wc = match self.source(db) {
73 SymFunctionSource::Function(ast) => ast.where_clauses(db),
74 SymFunctionSource::Constructor(_, ast) => ast.where_clauses(db),
75 SymFunctionSource::MainFunction(_) => &None,
76 };
77
78 if let Some(wc) = wc {
79 out.extend(wc.clauses(db));
80 }
81 }
82}
83
84impl<'db> Spanned<'db> for SymFunction<'db> {
85 fn span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
86 self.source(db).name(db).span
87 }
88}
89
90impl<'db> SourceSpanned<'db> for SymFunction<'db> {
91 fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
92 self.source(db).source_span(db)
93 }
94}
95
96#[salsa::tracked]
97impl<'db> SymFunction<'db> {
98 pub fn name(self, db: &'db dyn crate::Db) -> Identifier<'db> {
100 self.source(db).name(db).id
101 }
102
103 pub fn name_span(self, db: &'db dyn crate::Db) -> Span<'db> {
105 self.source(db).name(db).span
106 }
107
108 fn scope_from_symbols<'sym>(
109 self,
110 db: &'db dyn crate::Db,
111 symbols: &'sym SignatureSymbols<'db>,
112 ) -> Scope<'sym, 'db> {
113 self.super_scope_item(db)
114 .into_scope(db)
115 .with_link(Cow::Borrowed(&symbols.generic_variables[..]))
116 .with_link(Cow::Borrowed(&symbols.input_variables[..]))
117 }
118
119 #[salsa::tracked(return_ref)]
120 pub fn symbols(self, db: &'db dyn crate::Db) -> SignatureSymbols<'db> {
121 let source = self.source(db);
122
123 let mut just_explicit_symbols = SignatureSymbols::new(self);
127 source.populate_signature_symbols(db, &mut just_explicit_symbols);
128 let scope = self.scope_from_symbols(db, &just_explicit_symbols);
129
130 let mut with_default_symbols = just_explicit_symbols.clone();
132 source.populate_default_symbols(db, &scope, &mut with_default_symbols);
133
134 with_default_symbols
135 }
136
137 pub fn scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
140 let symbols = self.symbols(db);
141 self.scope_from_symbols(db, symbols)
142 }
143}
144
/// The syntax a [`SymFunction`] was created from.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Update, FromImpls, Serialize)]
pub enum SymFunctionSource<'db> {
    /// An explicitly declared function.
    Function(AstFunction<'db>),

    /// The `main` function, whose name is spanned over its statements.
    MainFunction(AstMainFunction<'db>),

    /// The constructor of an aggregate. Its inputs are derived from the aggregate's inputs.
    ///
    /// Marked `#[no_from_impl]`. Presumably this keeps `FromImpls` from generating a
    /// `From` conversion for this two-field variant (confirm against the macro).
    #[no_from_impl] Constructor(SymAggregate<'db>, AstAggregate<'db>),
}
156
157impl<'db> SymFunctionSource<'db> {
158 fn effects(self, db: &'db dyn crate::Db) -> AstFunctionEffects<'db> {
159 match self {
160 Self::Function(ast_function) => ast_function.effects(db),
161 Self::MainFunction(_) | Self::Constructor(..) => AstFunctionEffects::default(),
162 }
163 }
164
165 fn name(self, db: &'db dyn dada_ir_ast::Db) -> SpannedIdentifier<'db> {
166 match self {
167 Self::Function(ast_function) => ast_function.name(db),
168 Self::Constructor(class, _) => SpannedIdentifier {
169 span: class.name_span(db),
170 id: Identifier::new_ident(db),
171 },
172 Self::MainFunction(mfunc) => SpannedIdentifier {
173 span: mfunc.statements(db).span,
174 id: Identifier::main(db),
175 },
176 }
177 }
178
179 pub fn inputs(self, db: &'db dyn crate::Db) -> Cow<'db, [AstFunctionInput<'db>]> {
180 match self {
181 Self::Function(ast_function) => Cow::Borrowed(&ast_function.inputs(db).values),
182 Self::Constructor(_, class) => Cow::Owned(
183 class
184 .inputs(db)
185 .as_ref()
186 .unwrap()
187 .iter()
188 .map(|i| i.variable(db).into())
189 .collect::<Vec<_>>(),
190 ),
191 Self::MainFunction(_) => Cow::Borrowed(&[]),
192 }
193 }
194}
195
196impl<'db> SourceSpanned<'db> for SymFunctionSource<'db> {
197 fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
198 match self {
199 SymFunctionSource::Function(ast_function) => ast_function.span(db),
200 SymFunctionSource::Constructor(_, ast_aggregate) => ast_aggregate.span(db),
201 SymFunctionSource::MainFunction(mfunc) => mfunc.span(db),
202 }
203 }
204}
205
/// The effects a function is declared with, as computed by `SymFunction::effects`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct SymFunctionEffects {
    /// `true` if the function's source syntax declares an `async` effect.
    pub async_effect: bool,
}
211
/// A function signature in symbolic (`Sym*`) form.
#[derive(SalsaSerialize)]
#[salsa::tracked(debug)]
pub struct SymFunctionSignature<'db> {
    /// The generic and input variables declared by the signature.
    #[return_ref]
    pub symbols: SignatureSymbols<'db>,

    /// The input/output types and where-clauses, wrapped in two binders.
    ///
    /// NOTE(review): presumably the outer binder covers the generic variables and the
    /// inner one the input variables, mirroring `SignatureSymbols`. Confirm.
    #[return_ref]
    pub input_output: Binder<'db, Binder<'db, SymInputOutput<'db>>>,
}
225
/// The innermost term of `SymFunctionSignature::input_output`.
///
/// It holds the signature's types and where-clauses.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Update, Debug, Serialize)]
pub struct SymInputOutput<'db> {
    /// The types of the function's inputs (presumably in declaration order).
    pub input_tys: Vec<SymTy<'db>>,
    /// The function's return type.
    pub output_ty: SymTy<'db>,
    /// Where-clauses constraining the signature.
    pub where_clauses: Vec<SymWhereClause<'db>>,
}
232
// Marks `SymInputOutput` as a leaf term for `Binder`.
// It is the innermost value in `SymFunctionSignature::input_output`.
impl<'db> LeafBoundTerm<'db> for SymInputOutput<'db> {}
234
/// The variables declared by a function or class signature.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Update, Serialize)]
pub struct SignatureSymbols<'db> {
    /// The item this signature belongs to.
    pub source: SignatureSource<'db>,

    /// The generic parameters.
    ///
    /// For functions, these are linked into the scope before the input variables.
    pub generic_variables: Vec<SymVariable<'db>>,

    /// Variables bound to the signature's inputs.
    pub input_variables: Vec<SymVariable<'db>>,
}
246
247impl<'db> SignatureSymbols<'db> {
248 pub fn has_generics_of_kind(&self, db: &'db dyn crate::Db, kinds: &[SymGenericKind]) -> bool {
249 if self.generic_variables.len() != kinds.len() {
250 return false;
251 }
252 self.generic_variables
253 .iter()
254 .zip(kinds)
255 .all(|(&v, &k)| v.has_kind(db, k))
256 }
257}
258
259impl<'db> SignatureSymbols<'db> {
260 pub fn new(source: impl Into<SignatureSource<'db>>) -> Self {
264 Self {
265 source: source.into(),
266 generic_variables: Vec::new(),
267 input_variables: Vec::new(),
268 }
269 }
270}
271
/// The item a [`SignatureSymbols`] belongs to.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Update, FromImpls, Serialize)]
pub enum SignatureSource<'db> {
    /// The signature of an aggregate (class or other aggregate style).
    Class(SymAggregate<'db>),
    /// The signature of a function.
    Function(SymFunction<'db>),
}
277
278impl<'db> SignatureSource<'db> {
279 pub fn aggr_style(self, db: &'db dyn crate::Db) -> Option<SymAggregateStyle> {
280 match self {
281 SignatureSource::Class(aggr) => Some(aggr.style(db)),
282 SignatureSource::Function(_) => None,
283 }
284 }
285}