diff --git a/src/check_decl.cpp b/src/check_decl.cpp index 134dbb35b..42f68203c 100644 --- a/src/check_decl.cpp +++ b/src/check_decl.cpp @@ -1302,9 +1302,6 @@ void check_proc_body(CheckerContext *ctx_, Token token, DeclInfo *decl, Type *ty Type *t = base_type(type_deref(e->type)); if (t->kind == Type_Struct) { Scope *scope = t->Struct.scope; - if (scope == nullptr) { - scope = scope_of_node(t->Struct.node); - } GB_ASSERT(scope != nullptr); for_array(i, scope->elements.entries) { Entity *f = scope->elements.entries[i].value; diff --git a/src/check_stmt.cpp b/src/check_stmt.cpp index b5e3d8c88..1a424240c 100644 --- a/src/check_stmt.cpp +++ b/src/check_stmt.cpp @@ -635,10 +635,7 @@ bool check_using_stmt_entity(CheckerContext *ctx, AstUsingStmt *us, Ast *expr, b bool is_ptr = is_type_pointer(e->type); Type *t = base_type(type_deref(e->type)); if (t->kind == Type_Struct) { - Scope *found = scope_of_node(t->Struct.node); - if (found == nullptr) { - found = t->Struct.scope; - } + Scope *found = t->Struct.scope; GB_ASSERT(found != nullptr); for_array(i, found->elements.entries) { Entity *f = found->elements.entries[i].value; @@ -1399,9 +1396,9 @@ void check_block_stmt_for_errors(CheckerContext *ctx, Ast *body) { ast_node(bs, BlockStmt, body); // NOTE(bill, 2020-09-23): This logic is prevent common erros with block statements // e.g. if cond { x := 123; } // this is an error - if (body->scope != nullptr && body->scope->elements.entries.count > 0) { - if (body->scope->parent->node != nullptr) { - switch (body->scope->parent->node->kind) { + if (bs->scope != nullptr && bs->scope->elements.entries.count > 0) { + if (bs->scope->parent->node != nullptr) { + switch (bs->scope->parent->node->kind) { case Ast_IfStmt: case Ast_ForStmt: case Ast_RangeStmt: @@ -2339,7 +2336,8 @@ void check_stmt_internal(CheckerContext *ctx, Ast *node, u32 flags) { } else if (is_type_struct(t) || is_type_raw_union(t)) { ERROR_BLOCK(); - Scope *scope = scope_of_node(t->Struct.node); + Scope *scope = t->Struct.scope; + GB_ASSERT(scope != nullptr); for_array(i, scope->elements.entries) { Entity *f = scope->elements.entries[i].value; if (f->kind == Entity_Variable) { diff --git a/src/checker.cpp b/src/checker.cpp index 4ae8fd456..3caed256a 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -318,7 +318,43 @@ void add_scope(CheckerContext *c, Ast *node, Scope *scope) { GB_ASSERT(node != nullptr); GB_ASSERT(scope != nullptr); scope->node = node; - node->scope = scope; + switch (node->kind) { + case Ast_BlockStmt: node->BlockStmt.scope = scope; break; + case Ast_IfStmt: node->IfStmt.scope = scope; break; + case Ast_ForStmt: node->ForStmt.scope = scope; break; + case Ast_RangeStmt: node->RangeStmt.scope = scope; break; + case Ast_UnrollRangeStmt: node->UnrollRangeStmt.scope = scope; break; + case Ast_CaseClause: node->CaseClause.scope = scope; break; + case Ast_SwitchStmt: node->SwitchStmt.scope = scope; break; + case Ast_TypeSwitchStmt: node->TypeSwitchStmt.scope = scope; break; + case Ast_ProcType: node->ProcType.scope = scope; break; + case Ast_StructType: node->StructType.scope = scope; break; + case Ast_UnionType: node->UnionType.scope = scope; break; + case Ast_EnumType: node->EnumType.scope = scope; break; + default: GB_PANIC("Invalid node for add_scope: %.*s", LIT(ast_strings[node->kind])); + } +} + +Scope *scope_of_node(Ast *node) { + if (node == nullptr) { + return nullptr; + } + switch (node->kind) { + case Ast_BlockStmt: return node->BlockStmt.scope; + case Ast_IfStmt: return node->IfStmt.scope; + case Ast_ForStmt: return node->ForStmt.scope; + case Ast_RangeStmt: return node->RangeStmt.scope; + case Ast_UnrollRangeStmt: return node->UnrollRangeStmt.scope; + case Ast_CaseClause: return node->CaseClause.scope; + case Ast_SwitchStmt: return node->SwitchStmt.scope; + case Ast_TypeSwitchStmt: return node->TypeSwitchStmt.scope; + case Ast_ProcType: return node->ProcType.scope; + case Ast_StructType: return node->StructType.scope; + case Ast_UnionType: return node->UnionType.scope; + case Ast_EnumType: return node->EnumType.scope; + } + GB_PANIC("Invalid node for add_scope: %.*s", LIT(ast_strings[node->kind])); + return nullptr; } @@ -1081,9 +1117,6 @@ AstFile *ast_file_of_filename(CheckerInfo *i, String filename) { } return nullptr; } -Scope *scope_of_node(Ast *node) { - return node->scope; -} ExprInfo *check_get_expr_info(CheckerContext *c, Ast *expr) { if (c->untyped != nullptr) { ExprInfo **found = map_get(c->untyped, expr); diff --git a/src/checker.hpp b/src/checker.hpp index 6511dad32..74435c1d4 100644 --- a/src/checker.hpp +++ b/src/checker.hpp @@ -410,7 +410,6 @@ gb_global AstPackage *config_pkg = nullptr; TypeAndValue type_and_value_of_expr (Ast *expr); Type * type_of_expr (Ast *expr); Entity * implicit_entity_of_node(Ast *clause); -Scope * scope_of_node (Ast *node); DeclInfo * decl_info_of_ident (Ast *ident); DeclInfo * decl_info_of_entity (Entity * e); AstFile * ast_file_of_filename (CheckerInfo *i, String filename); diff --git a/src/llvm_backend_stmt.cpp b/src/llvm_backend_stmt.cpp index ec8ded7fa..c2ff0dfe1 100644 --- a/src/llvm_backend_stmt.cpp +++ b/src/llvm_backend_stmt.cpp @@ -1625,7 +1625,7 @@ void lb_build_return_stmt(lbProcedure *p, Slice const &return_results) { void lb_build_if_stmt(lbProcedure *p, Ast *node) { ast_node(is, IfStmt, node); - lb_open_scope(p, node->scope); // Scope #1 + lb_open_scope(p, is->scope); // Scope #1 defer (lb_close_scope(p, lbDeferExit_Default, nullptr)); if (is->init != nullptr) { @@ -1675,7 +1675,7 @@ void lb_build_if_stmt(lbProcedure *p, Ast *node) { lb_emit_jump(p, else_); lb_start_block(p, else_); - lb_open_scope(p, is->else_stmt->scope); + lb_open_scope(p, scope_of_node(is->else_stmt)); lb_build_stmt(p, is->else_stmt); lb_close_scope(p, lbDeferExit_Default, nullptr); } @@ -1692,7 +1692,7 @@ void lb_build_if_stmt(lbProcedure *p, Ast *node) { if (is->else_stmt != nullptr) { lb_start_block(p, else_); - lb_open_scope(p, is->else_stmt->scope); + lb_open_scope(p, scope_of_node(is->else_stmt)); lb_build_stmt(p, is->else_stmt); lb_close_scope(p, lbDeferExit_Default, nullptr); @@ -1710,7 +1710,7 @@ void lb_build_if_stmt(lbProcedure *p, Ast *node) { void lb_build_for_stmt(lbProcedure *p, Ast *node) { ast_node(fs, ForStmt, node); - lb_open_scope(p, node->scope); // Open Scope here + lb_open_scope(p, fs->scope); // Open Scope here if (fs->init != nullptr) { #if 1 @@ -2056,7 +2056,7 @@ void lb_build_stmt(lbProcedure *p, Ast *node) { tl->is_block = true; } - lb_open_scope(p, node->scope); + lb_open_scope(p, bs->scope); lb_build_stmt_list(p, bs->stmts); lb_close_scope(p, lbDeferExit_Default, nullptr); @@ -2137,15 +2137,15 @@ void lb_build_stmt(lbProcedure *p, Ast *node) { case_end; case_ast_node(rs, RangeStmt, node); - lb_build_range_stmt(p, rs, node->scope); + lb_build_range_stmt(p, rs, rs->scope); case_end; case_ast_node(rs, UnrollRangeStmt, node); - lb_build_unroll_range_stmt(p, rs, node->scope); + lb_build_unroll_range_stmt(p, rs, rs->scope); case_end; case_ast_node(ss, SwitchStmt, node); - lb_build_switch_stmt(p, ss, node->scope); + lb_build_switch_stmt(p, ss, ss->scope); case_end; case_ast_node(ss, TypeSwitchStmt, node); diff --git a/src/parser.hpp b/src/parser.hpp index 76ae33b21..b1518533e 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -424,11 +424,13 @@ AST_KIND(_StmtBegin, "", bool) \ }) \ AST_KIND(_ComplexStmtBegin, "", bool) \ AST_KIND(BlockStmt, "block statement", struct { \ + Scope *scope; \ Slice stmts; \ Ast *label; \ Token open, close; \ }) \ AST_KIND(IfStmt, "if statement", struct { \ + Scope *scope; \ Token token; \ Ast *label; \ Ast * init; \ @@ -449,6 +451,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \ Slice results; \ }) \ AST_KIND(ForStmt, "for statement", struct { \ + Scope *scope; \ Token token; \ Ast *label; \ Ast *init; \ @@ -457,6 +460,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \ Ast *body; \ }) \ AST_KIND(RangeStmt, "range statement", struct { \ + Scope *scope; \ Token token; \ Ast *label; \ Slice vals; \ @@ -465,6 +469,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \ Ast *body; \ }) \ AST_KIND(UnrollRangeStmt, "#unroll range statement", struct { \ + Scope *scope; \ Token unroll_token; \ Token for_token; \ Ast *val0; \ @@ -474,12 +479,14 @@ AST_KIND(_ComplexStmtBegin, "", bool) \ Ast *body; \ }) \ AST_KIND(CaseClause, "case clause", struct { \ + Scope *scope; \ Token token; \ Slice list; \ Slice stmts; \ Entity *implicit_entity; \ }) \ AST_KIND(SwitchStmt, "switch statement", struct { \ + Scope *scope; \ Token token; \ Ast *label; \ Ast *init; \ @@ -488,6 +495,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \ bool partial; \ }) \ AST_KIND(TypeSwitchStmt, "type switch statement", struct { \ + Scope *scope; \ Token token; \ Ast *label; \ Ast *tag; \ @@ -589,6 +597,7 @@ AST_KIND(_TypeBegin, "", bool) \ Ast * specialization; \ }) \ AST_KIND(ProcType, "procedure type", struct { \ + Scope *scope; \ Token token; \ Ast *params; \ Ast *results; \ @@ -621,6 +630,7 @@ AST_KIND(_TypeBegin, "", bool) \ Ast *tag; \ }) \ AST_KIND(StructType, "struct type", struct { \ + Scope *scope; \ Token token; \ Slice fields; \ isize field_count; \ @@ -632,6 +642,7 @@ AST_KIND(_TypeBegin, "", bool) \ bool is_raw_union; \ }) \ AST_KIND(UnionType, "union type", struct { \ + Scope *scope; \ Token token; \ Slice variants; \ Ast *polymorphic_params; \ @@ -642,6 +653,7 @@ AST_KIND(_TypeBegin, "", bool) \ Slice where_clauses; \ }) \ AST_KIND(EnumType, "enum type", struct { \ + Scope *scope; \ Token token; \ Ast * base_type; \ Slice fields; /* FieldValue */ \ @@ -695,21 +707,19 @@ isize const ast_variant_sizes[] = { }; struct AstCommonStuff { - AstKind kind; + AstKind kind; // u16 u8 state_flags; u8 viral_state_flags; i32 file_id; - Scope * scope; - TypeAndValue tav; // TODO(bill): Make this a pointer to minimize pointer size + TypeAndValue tav; // TODO(bill): Make this a pointer to minimize 'Ast' size }; struct Ast { - AstKind kind; + AstKind kind; // u16 u8 state_flags; u8 viral_state_flags; i32 file_id; - Scope * scope; - TypeAndValue tav; // TODO(bill): Make this a pointer to minimize pointer size + TypeAndValue tav; // TODO(bill): Make this a pointer to minimize 'Ast' size // IMPORTANT NOTE(bill): This must be at the end since the AST is allocated to be size of the variant union {