Skip to main content

dada_codegen/cx/
generate_fn.rs

1use dada_ir_ast::diagnostic::Err;
2use dada_ir_sym::{
3    ir::{
4        functions::{SymFunction, SymInputOutput},
5        types::{SymGenericTerm, SymPlace, SymTy},
6        variables::SymVariable,
7    },
8    prelude::{CheckedBody, CheckedSignature},
9};
10use dada_util::Map;
11use wasm_encoder::ValType;
12
13use super::{Cx, FnIndex, FnKey, generate_expr::ExprCodegen, wasm_repr::WasmReprCx};
14
15impl<'db> Cx<'db> {
16    /// Declares an instantiation of a function with a given set of arguments and returns its index.
17    /// If the function is already declared, nothing happens.
18    /// If the function is not already declared, it is enqueued for code-generation.
19    pub(crate) fn declare_fn(
20        &mut self,
21        function: SymFunction<'db>,
22        generics: Vec<SymGenericTerm<'db>>,
23    ) -> FnIndex {
24        let key = FnKey(function, generics);
25        let generics: &Vec<SymGenericTerm<'_>> = &key.1;
26
27        // Check if we already declared this function and return the result if so
28        if let Some(index) = self.functions.get(&key).copied() {
29            return index;
30        }
31
32        // Extract function signature
33        let CodegenSignature {
34            inputs: _,
35            ref generics,
36            input_output:
37                SymInputOutput {
38                    input_tys,
39                    output_ty,
40                    where_clauses: _,
41                },
42        } = self.codegen_signature(function, generics);
43
44        // Create the type for this function
45        let ty_index = {
46            let mut wrcx = WasmReprCx::new(self.db, generics);
47            // The first input is the stack pointer.
48            // The remainder are the values given by the user.
49            let input_val_types = std::iter::once(ValType::I32)
50                .chain(
51                    input_tys
52                        .iter()
53                        .flat_map(|&t| wrcx.wasm_repr_of_type(t).flatten()),
54                )
55                .collect::<Vec<_>>();
56            let output_val_types = wrcx.wasm_repr_of_type(output_ty).flatten();
57            self.declare_fn_type(input_val_types, output_val_types)
58        };
59
60        // Add to the WASM function section
61        let fn_index = FnIndex(self.function_section.len());
62        self.function_section.function(u32::from(ty_index));
63
64        // Record on the queue to generate code
65        self.codegen_queue.push(key.clone().into());
66
67        // Memoize the result for later
68        self.functions.insert(key, fn_index);
69
70        fn_index
71    }
72
73    pub(crate) fn codegen_fn(&mut self, FnKey(function, generics): FnKey<'db>) {
74        let db = self.db;
75
76        let object_check_body = match function.checked_body(self.db) {
77            Some(body) => body,
78            None => panic!("asked to codegen function with no body: {function:?}"),
79        };
80
81        let CodegenSignature {
82            inputs,
83            generics,
84            input_output,
85        } = self.codegen_signature(function, &generics);
86
87        // Generate the function body.
88        let function = {
89            let mut ecx = ExprCodegen::new(self, generics);
90            ecx.pop_arguments(inputs, &input_output.input_tys);
91            ecx.push_expr(object_check_body);
92            ecx.pop_and_return(object_check_body.ty(db));
93            ecx.into_function()
94        };
95
96        self.code_section.function(&function);
97    }
98
99    fn codegen_signature(
100        &self,
101        function: SymFunction<'db>,
102        generics: &[SymGenericTerm<'db>],
103    ) -> CodegenSignature<'db> {
104        match function.checked_signature(self.db) {
105            Ok(signature) => {
106                let symbols = signature.symbols(self.db);
107
108                let input_output = signature
109                    .input_output(self.db)
110                    .substitute(self.db, generics);
111                let dummy_places = symbols
112                    .input_variables
113                    .iter()
114                    .map(|_| SymGenericTerm::Place(SymPlace::erased(self.db)))
115                    .collect::<Vec<_>>();
116                let input_output = input_output.substitute(self.db, &dummy_places);
117
118                CodegenSignature {
119                    inputs: &symbols.input_variables,
120                    generics: symbols
121                        .generic_variables
122                        .iter()
123                        .copied()
124                        .zip(generics.iter().copied())
125                        .collect(),
126                    input_output,
127                }
128            }
129
130            Err(reported) => CodegenSignature {
131                inputs: &[],
132                generics: Default::default(),
133                input_output: SymInputOutput {
134                    input_tys: vec![],
135                    output_ty: SymTy::err(self.db, reported),
136                    where_clauses: vec![],
137                },
138            },
139        }
140    }
141}
142
143struct CodegenSignature<'db> {
144    inputs: &'db [SymVariable<'db>],
145    generics: Map<SymVariable<'db>, SymGenericTerm<'db>>,
146    input_output: SymInputOutput<'db>,
147}