Skip to main content

dada_ir_sym/ir/
classes.rs

1use std::borrow::Cow;
2
3use dada_ir_ast::{
4    ast::{AstAggregate, AstAggregateKind, AstFieldDecl, AstMember, Identifier, SpannedIdentifier},
5    span::{SourceSpanned, Span, Spanned},
6};
7use dada_parser::prelude::*;
8use dada_util::{FromImpls, SalsaSerialize};
9use salsa::Update;
10
11use crate::{
12    check::scope::Scope,
13    check::scope_tree::{ScopeItem, ScopeTreeNode},
14    ir::functions::{SignatureSymbols, SymFunction, SymFunctionSource},
15    ir::populate::PopulateSignatureSymbols,
16    ir::types::{SymGenericKind, SymTy, SymTyKind},
17    ir::variables::SymVariable,
18    prelude::Symbol,
19};
20
21use super::types::Variance;
22
/// Symbol for an aggregate declaration: either a `class` or a `struct`
/// (see [`SymAggregateStyle`] and `SymAggregate::style`).
#[derive(SalsaSerialize)]
#[salsa::tracked(debug)]
pub struct SymAggregate<'db> {
    /// The scope in which this aggregate is declared.
    super_scope: ScopeItem<'db>,

    /// The AST for this aggregate.
    source: AstAggregate<'db>,
}
32
33#[salsa::tracked]
34impl<'db> SymAggregate<'db> {
35    /// Name of the class.
36    pub fn name(&self, db: &'db dyn salsa::Database) -> Identifier<'db> {
37        self.source(db).name(db)
38    }
39
40    /// Aggregate style (struct, etc)
41    pub fn style(self, db: &'db dyn crate::Db) -> SymAggregateStyle {
42        match self.source(db).kind(db) {
43            AstAggregateKind::Class => SymAggregateStyle::Class,
44            AstAggregateKind::Struct => SymAggregateStyle::Struct,
45        }
46    }
47
48    /// True if this is a struct
49    pub fn is_struct(self, db: &'db dyn crate::Db) -> bool {
50        self.style(db) == SymAggregateStyle::Struct
51    }
52
53    /// True if this is a class
54    pub fn is_class(self, db: &'db dyn crate::Db) -> bool {
55        self.style(db) == SymAggregateStyle::Class
56    }
57
58    /// Number of generic parameters
59    pub fn len_generics(&self, db: &'db dyn crate::Db) -> usize {
60        if let Some(generics) = self.source(db).generics(db) {
61            generics.len()
62        } else {
63            0
64        }
65    }
66
67    /// Variance of generic parameters
68    pub fn variances(&self, db: &'db dyn crate::Db) -> Vec<Variance> {
69        let len_generics = self.len_generics(db);
70        // FIXME
71        vec![Variance::covariant(); len_generics]
72    }
73
74    /// Kinds of generic parameters
75    pub fn generic_kinds(
76        &self,
77        db: &'db dyn crate::Db,
78    ) -> impl Iterator<Item = SymGenericKind> + 'db {
79        self.source(db)
80            .generics(db)
81            .iter()
82            .flatten()
83            .map(move |decl| decl.kind(db).symbol(db))
84    }
85
86    /// Span of the class name, typically used in diagnostics.
87    /// Also returned by the [`Spanned`][] impl.
88    pub fn name_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
89        self.source(db).name_span(db)
90    }
91
92    /// Span where generics are declared (possibly the name span, if there are no generics)
93    pub fn generics_span(&self, db: &'db dyn crate::Db) -> Span<'db> {
94        if let Some(generics) = self.source(db).generics(db) {
95            generics.span
96        } else {
97            self.name_span(db)
98        }
99    }
100
101    /// Span where the `index`th generics are is (possibly the name span, if there are no generics)
102    ///
103    /// # Panics
104    ///
105    /// If `index` is not a valid generic index
106    pub fn generic_span(&self, db: &'db dyn crate::Db, index: usize) -> Span<'db> {
107        let Some(generic) = self.source(db).generics(db).iter().flatten().nth(index) else {
108            panic!(
109                "invalid generic index `{index}` for `{name}`",
110                name = self.name(db)
111            )
112        };
113        generic.span(db)
114    }
115
116    /// Returns the symbols for this class header (generic arguments).
117    #[salsa::tracked(return_ref)]
118    pub(crate) fn symbols(self, db: &'db dyn crate::Db) -> SignatureSymbols<'db> {
119        let mut signature_symbols = SignatureSymbols::new(self);
120        self.source(db)
121            .populate_signature_symbols(db, &mut signature_symbols);
122
123        // NB: classes have no default symbols
124
125        signature_symbols
126    }
127
128    /// Returns the base scope used to resolve the class members.
129    /// Typically this is created by invoke [`Scope::new`][].
130    pub(crate) fn class_scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
131        let symbols = self.symbols(db);
132        assert!(symbols.input_variables.is_empty());
133        self.super_scope(db)
134            .into_scope(db)
135            .with_link(self)
136            .with_link(Cow::Borrowed(&symbols.generic_variables[..]))
137    }
138
139    /// Returns the type of this class, referencing the generics that appear in `scope`.
140    pub fn self_ty(self, db: &'db dyn crate::Db, scope: &Scope<'_, 'db>) -> SymTy<'db> {
141        SymTy::new(
142            db,
143            SymTyKind::Named(
144                self.into(),
145                self.source(db)
146                    .generics(db)
147                    .iter()
148                    .flatten()
149                    .map(|g| g.symbol(db))
150                    .map(|g| g.into_generic_term(db, scope))
151                    .collect(),
152            ),
153        )
154    }
155
156    /// Tracked list of class members.
157    #[salsa::tracked(return_ref)]
158    pub fn members(self, db: &'db dyn crate::Db) -> Vec<SymClassMember<'db>> {
159        // If the class is declared like `class Foo(x: u32, y: u32)` then we make a constructor `new`
160        // and a field for each of those members
161        let ctor_members = self.source(db).inputs(db).iter().flat_map(|inputs| {
162            let ctor = SymFunction::new(
163                db,
164                self.into(),
165                SymFunctionSource::Constructor(self, self.source(db)),
166            )
167            .into();
168
169            let fields = inputs.iter().map(|field_decl| {
170                SymField::new(
171                    db,
172                    self.into(),
173                    field_decl.variable(db).name(db).id,
174                    field_decl.variable(db).name(db).span,
175                    *field_decl,
176                )
177                .into()
178            });
179
180            std::iter::once(ctor).chain(fields)
181        });
182
183        // Also include anything the user explicitly wrote
184        let explicit_members = self.source(db).members(db).iter().map(|m| match *m {
185            AstMember::Field(ast_field_decl) => {
186                let SpannedIdentifier { span, id } = ast_field_decl.variable(db).name(db);
187                SymField::new(db, self.into(), id, span, ast_field_decl).into()
188            }
189            AstMember::Function(ast_function) => {
190                SymFunction::new(db, self.into(), ast_function.into()).into()
191            }
192        });
193
194        ctor_members.chain(explicit_members).collect()
195    }
196
197    /// Returns the member with the given name, if it exists.
198    #[salsa::tracked]
199    pub fn inherent_member(
200        self,
201        db: &'db dyn crate::Db,
202        id: Identifier<'db>,
203    ) -> Option<SymClassMember<'db>> {
204        self.members(db)
205            .iter()
206            .copied()
207            .find(|m| m.has_name(db, id))
208    }
209
210    /// Returns the member with the given name, if it exists.
211    pub fn inherent_member_str(
212        self,
213        db: &'db dyn crate::Db,
214        id: &str,
215    ) -> Option<SymClassMember<'db>> {
216        self.inherent_member(db, Identifier::new(db, id))
217    }
218
219    /// Returns iterator over all fields in this class.
220    pub fn fields(self, db: &'db dyn crate::Db) -> impl Iterator<Item = SymField<'db>> {
221        self.members(db).iter().filter_map(|&m| match m {
222            SymClassMember::SymField(f) => Some(f),
223            _ => None,
224        })
225    }
226
227    /// Returns iterator over all methods in this class.
228    pub fn methods(self, db: &'db dyn crate::Db) -> impl Iterator<Item = SymFunction<'db>> {
229        self.members(db).iter().filter_map(|&m| match m {
230            SymClassMember::SymFunction(f) => Some(f),
231            _ => None,
232        })
233    }
234}
235
236impl std::fmt::Display for SymAggregate<'_> {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        salsa::with_attached_database(|db| write!(f, "{}", self.name(db)))
239            .unwrap_or_else(|| std::fmt::Debug::fmt(self, f))
240    }
241}
242
243impl<'db> ScopeTreeNode<'db> for SymAggregate<'db> {
244    fn direct_super_scope(self, db: &'db dyn crate::Db) -> Option<ScopeItem<'db>> {
245        Some(self.super_scope(db))
246    }
247
248    fn direct_generic_parameters(self, db: &'db dyn crate::Db) -> &'db Vec<SymVariable<'db>> {
249        &self.symbols(db).generic_variables
250    }
251
252    fn into_scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
253        self.class_scope(db)
254    }
255
256    fn push_direct_ast_where_clauses(
257        self,
258        db: &'db dyn crate::Db,
259        out: &mut Vec<dada_ir_ast::ast::AstWhereClause<'db>>,
260    ) {
261        if let Some(wc) = self.source(db).where_clauses(db) {
262            out.extend(wc.clauses(db));
263        }
264    }
265}
266
267impl<'db> Spanned<'db> for SymAggregate<'db> {
268    fn span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
269        self.name_span(db)
270    }
271}
272
273impl<'db> SourceSpanned<'db> for SymAggregate<'db> {
274    fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
275        self.source(db).span(db)
276    }
277}
278
/// Which kind of aggregate a [`SymAggregate`] is; computed by `SymAggregate::style`
/// from the AST's `AstAggregateKind`.
///
/// Note: variant order determines the derived `PartialOrd`/`Ord` (`Struct < Class`).
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum SymAggregateStyle {
    /// Corresponds to `AstAggregateKind::Struct`.
    Struct,
    /// Corresponds to `AstAggregateKind::Class`.
    Class,
}
284
/// Symbol for a class member: either a field or a function.
///
/// The list of members for an aggregate is produced by `SymAggregate::members`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, FromImpls, Update)]
pub enum SymClassMember<'db> {
    /// Class fields (explicit, or synthesized from header inputs)
    SymField(SymField<'db>),

    /// Class methods (including any synthesized constructor)
    SymFunction(SymFunction<'db>),
}
294
295impl<'db> SymClassMember<'db> {
296    /// True if this class member has the given name.
297    pub fn has_name(self, db: &'db dyn crate::Db, id: Identifier<'db>) -> bool {
298        match self {
299            SymClassMember::SymField(f) => f.name(db) == id,
300            SymClassMember::SymFunction(f) => f.name(db) == id,
301        }
302    }
303}
304
305impl<'db> Spanned<'db> for SymClassMember<'db> {
306    fn span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
307        match self {
308            SymClassMember::SymField(f) => f.name_span(db),
309            SymClassMember::SymFunction(f) => f.name_span(db),
310        }
311    }
312}
313
314impl<'db> SourceSpanned<'db> for SymClassMember<'db> {
315    fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
316        match self {
317            SymClassMember::SymField(f) => f.source_span(db),
318            SymClassMember::SymFunction(f) => f.source_span(db),
319        }
320    }
321}
322
/// Symbol for a field of a class, struct, or enum.
///
/// Created by `SymAggregate::members`, both for explicitly declared fields and
/// for fields synthesized from a class header's input list.
#[derive(SalsaSerialize)]
#[salsa::tracked(debug)]
pub struct SymField<'db> {
    /// The item in which this field is declared.
    pub scope_item: ScopeItem<'db>,

    /// Field name
    pub name: Identifier<'db>,

    /// Span of field name. Also returned by [`Spanned`][] impl.
    pub name_span: Span<'db>,

    /// AST for field declaration
    pub source: AstFieldDecl<'db>,
}
339
340#[salsa::tracked]
341impl<'db> SymField<'db> {
342    /// The symbol for the `self` variable that appears in this field's type.
343    /// (Every field and class member has their own `self` symbol.)
344    #[salsa::tracked]
345    pub fn self_sym(self, db: &'db dyn crate::Db) -> SymVariable<'db> {
346        SymVariable::new(
347            db,
348            SymGenericKind::Place,
349            Some(Identifier::self_ident(db)),
350            self.name_span(db),
351        )
352    }
353
354    /// The scope for resolving the type of this field.
355    pub fn into_scope(self, db: &'db dyn crate::Db) -> Scope<'db, 'db> {
356        let self_sym = self.self_sym(db);
357        self.scope_item(db).into_scope(db).with_link(self_sym)
358    }
359}
360
361impl<'db> Spanned<'db> for SymField<'db> {
362    fn span(&self, db: &'db dyn dada_ir_ast::Db) -> dada_ir_ast::span::Span<'db> {
363        self.name_span(db)
364    }
365}
366
367impl<'db> SourceSpanned<'db> for SymField<'db> {
368    fn source_span(&self, db: &'db dyn dada_ir_ast::Db) -> Span<'db> {
369        self.source(db).span(db)
370    }
371}
372
373impl std::fmt::Display for SymField<'_> {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        salsa::with_attached_database(|db| write!(f, "{}", self.name(db)))
376            .unwrap_or_else(|| std::fmt::Debug::fmt(self, f))
377    }
378}