diff --git a/core/odin/ast/ast.odin b/core/odin/ast/ast.odin
index f62feec8c..3b8998b31 100644
--- a/core/odin/ast/ast.odin
+++ b/core/odin/ast/ast.odin
@@ -432,10 +432,13 @@ Range_Stmt :: struct {
 	reverse: bool,
 }
 
-Inline_Range_Stmt :: struct {
+Inline_Range_Stmt :: Unroll_Range_Stmt
+
+Unroll_Range_Stmt :: struct {
 	using node: Stmt,
 	label:      ^Expr,
-	inline_pos: tokenizer.Pos,
+	unroll_pos: tokenizer.Pos,
+	args:       []^Expr,
 	for_pos:    tokenizer.Pos,
 	val0:       ^Expr,
 	val1:       ^Expr,
diff --git a/core/odin/ast/clone.odin b/core/odin/ast/clone.odin
index 67f7ffa95..b7501e6ca 100644
--- a/core/odin/ast/clone.odin
+++ b/core/odin/ast/clone.odin
@@ -242,8 +242,9 @@ clone_node :: proc(node: ^Node) -> ^Node {
 		r.vals = clone(r.vals)
 		r.expr = clone(r.expr)
 		r.body = clone(r.body)
-	case ^Inline_Range_Stmt:
+	case ^Unroll_Range_Stmt:
 		r.label = clone(r.label)
+		r.args = clone(r.args)
 		r.val0 = clone(r.val0)
 		r.val1 = clone(r.val1)
 		r.expr = clone(r.expr)
diff --git a/core/odin/parser/parser.odin b/core/odin/parser/parser.odin
index 5a7440339..63c7e388f 100644
--- a/core/odin/parser/parser.odin
+++ b/core/odin/parser/parser.odin
@@ -1262,11 +1262,49 @@ parse_foreign_decl :: proc(p: ^Parser) -> ^ast.Decl {
 
 parse_unrolled_for_loop :: proc(p: ^Parser, inline_tok: tokenizer.Token) -> ^ast.Stmt {
-	for_tok := expect_token(p, .For)
 
 	val0, val1: ^ast.Expr
 	in_tok: tokenizer.Token
 	expr: ^ast.Expr
 	body: ^ast.Stmt
+	args: [dynamic]^ast.Expr
+
+	if allow_token(p, .Open_Paren) {
+		p.expr_level += 1
+		if p.curr_tok.kind == .Close_Paren {
+			error(p, p.curr_tok.pos, "#unroll expected at least 1 argument, got 0")
+		} else {
+			args = make([dynamic]^ast.Expr)
+			for p.curr_tok.kind != .Close_Paren &&
+			    p.curr_tok.kind != .EOF {
+				arg := parse_value(p)
+
+				if p.curr_tok.kind == .Eq {
+					eq := expect_token(p, .Eq)
+					if arg != nil {
+						if _, ok := arg.derived.(^ast.Ident); !ok {
+							error(p, arg.pos, "expected an identifier for 'key=value'")
+						}
+					}
+					value := parse_value(p)
+					fv := ast.new(ast.Field_Value, arg.pos, value)
+					fv.field = arg
+					fv.sep = eq.pos
+					fv.value = value
+
+					arg = fv
+				}
+
+				append(&args, arg)
+
+				allow_token(p, .Comma) or_break
+			}
+		}
+
+		p.expr_level -= 1
+		_ = expect_token_after(p, .Close_Paren, "#unroll")
+	}
+
+	for_tok := expect_token(p, .For)
 
 	bad_stmt := false
 
@@ -1309,7 +1347,8 @@ parse_unrolled_for_loop :: proc(p: ^Parser, inline_tok: tokenizer.Token) -> ^ast
 	}
 
 	range_stmt := ast.new(ast.Inline_Range_Stmt, inline_tok.pos, body)
-	range_stmt.inline_pos = inline_tok.pos
+	range_stmt.unroll_pos = inline_tok.pos
+	range_stmt.args = args[:]
 	range_stmt.for_pos = for_tok.pos
 	range_stmt.val0 = val0
 	range_stmt.val1 = val1
diff --git a/src/check_stmt.cpp b/src/check_stmt.cpp
index 02ad72388..1708f7c81 100644
--- a/src/check_stmt.cpp
+++ b/src/check_stmt.cpp
@@ -894,15 +894,49 @@ gb_internal void error_var_decl_identifier(Ast *name) {
 	}
 }
 
-gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
+gb_internal void check_unroll_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
 	ast_node(irs, UnrollRangeStmt, node);
 
 	check_open_scope(ctx, node);
+	defer (check_close_scope(ctx));
 
 	Type *val0 = nullptr;
 	Type *val1 = nullptr;
 	Entity *entities[2] = {};
 	isize entity_count = 0;
 
+	i64 unroll_count = -1;
+
+	if (irs->args.count > 0) {
+		if (irs->args.count > 1) {
+			error(irs->args[1], "#unroll only supports a single argument for the unroll per loop amount");
+		}
+		Ast *arg = irs->args[0];
+		if (arg->kind == Ast_FieldValue) {
+			error(arg, "#unroll does not yet support named arguments");
+			arg = arg->FieldValue.value;
+		}
+
+		Operand x = {};
+		check_expr(ctx, &x, arg);
+		if (x.mode != Addressing_Constant || !is_type_integer(x.type)) {
+			gbString s = expr_to_string(x.expr);
+			error(x.expr, "Expected a constant integer for #unroll, got '%s'", s);
+			gb_string_free(s);
+		} else {
+			ExactValue value = exact_value_to_integer(x.value);
+			i64 v = exact_value_to_i64(value);
+			if (v < 1) {
+				error(x.expr, "Expected a constant integer >= 1 for #unroll, got %lld", cast(long long)v);
+			} else {
+				unroll_count = v;
+				if (v > 1024) {
+					error(x.expr, "Too large of a value for #unroll, got %lld, expected <= 1024", cast(long long)v);
+				}
+			}
+
+		}
+	}
+
 	Ast *expr = unparen_expr(irs->expr);
 
 	ExactValue inline_for_depth = exact_value_i64(0);
@@ -946,18 +980,39 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
 			val0 = t_rune;
 			val1 = t_int;
 			inline_for_depth = exact_value_i64(operand.value.value_string.len);
+			if (unroll_count > 0) {
+				error(node, "#unroll(%lld) does not support strings", cast(long long)unroll_count);
+			}
 			break;
 		case Type_Array:
 			val0 = t->Array.elem;
 			val1 = t_int;
-			inline_for_depth = exact_value_i64(t->Array.count);
+			inline_for_depth = unroll_count > 0 ? exact_value_i64(unroll_count) : exact_value_i64(t->Array.count);
 			break;
 		case Type_EnumeratedArray:
 			val0 = t->EnumeratedArray.elem;
 			val1 = t->EnumeratedArray.index;
+			if (unroll_count > 0) {
+				error(node, "#unroll(%lld) does not support enumerated arrays", cast(long long)unroll_count);
+			}
 			inline_for_depth = exact_value_i64(t->EnumeratedArray.count);
 			break;
+
+		case Type_Slice:
+			if (unroll_count > 0) {
+				val0 = t->Slice.elem;
+				val1 = t_int;
+				inline_for_depth = exact_value_i64(unroll_count);
+			}
+			break;
+		case Type_DynamicArray:
+			if (unroll_count > 0) {
+				val0 = t->DynamicArray.elem;
+				val1 = t_int;
+				inline_for_depth = exact_value_i64(unroll_count);
+			}
+			break;
 		}
 	}
 
@@ -967,7 +1022,7 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
 			error(operand.expr, "Cannot iterate over '%s' of type '%s' in an '#unroll for' statement", s, t);
 			gb_string_free(t);
 			gb_string_free(s);
-		} else if (operand.mode != Addressing_Constant) {
+		} else if (operand.mode != Addressing_Constant && unroll_count <= 0) {
 			error(operand.expr, "An '#unroll for' expression must be known at compile time");
 		}
 	}
@@ -1050,8 +1105,6 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
 
 	check_stmt(ctx, irs->body, mod_flags);
-
-
-	check_close_scope(ctx);
 }
 
 gb_internal void check_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
@@ -2679,7 +2732,7 @@ gb_internal void check_stmt_internal(CheckerContext *ctx, Ast *node, u32 flags)
 	case_end;
 
 	case_ast_node(irs, UnrollRangeStmt, node);
-		check_inline_range_stmt(ctx, node, mod_flags);
+		check_unroll_range_stmt(ctx, node, mod_flags);
 	case_end;
 
 	case_ast_node(ss, SwitchStmt, node);
diff --git a/src/llvm_backend_stmt.cpp b/src/llvm_backend_stmt.cpp
index a2f0d2f4a..b05df0b46 100644
--- a/src/llvm_backend_stmt.cpp
+++ b/src/llvm_backend_stmt.cpp
@@ -256,7 +256,7 @@ gb_internal void lb_build_when_stmt(lbProcedure *p, AstWhenStmt *ws) {
 gb_internal void lb_build_range_indexed(lbProcedure *p, lbValue expr, Type *val_type,
                                         lbValue count_ptr, lbValue *val_, lbValue *idx_,
                                         lbBlock **loop_, lbBlock **done_,
-                                        bool is_reverse) {
+                                        bool is_reverse, i64 unroll_count=0) {
 	lbModule *m = p->module;
 
 	lbValue count = {};
@@ -1230,7 +1230,6 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
 	TypeAndValue tav = type_and_value_of_expr(expr);
 
 	if (is_ast_range(expr)) {
-
 		lbAddr val0_addr = {};
 		lbAddr val1_addr = {};
 		if (val0_type) val0_addr = lb_build_addr(p, val0);
@@ -1268,7 +1267,6 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
 			}
 		}
 
-
 	} else if (tav.mode == Addressing_Type) {
 		GB_ASSERT(is_type_enum(type_deref(tav.type)));
 		Type *et = type_deref(tav.type);
@@ -1293,72 +1291,203 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
 		if (val0_type) val0_addr = lb_build_addr(p, val0);
 		if (val1_type) val1_addr = lb_build_addr(p, val1);
 
-		GB_ASSERT(expr->tav.mode == Addressing_Constant);
-
-		Type *t = base_type(expr->tav.type);
+		ExactValue unroll_count_ev = {};
+		if (rs->args.count != 0) {
+			unroll_count_ev = rs->args[0]->tav.value;
+		}
 
-		switch (t->kind) {
-		case Type_Basic:
-			GB_ASSERT(is_type_string(t));
-			{
-				ExactValue value = expr->tav.value;
-				GB_ASSERT(value.kind == ExactValue_String);
-				String str = value.value_string;
-				Rune codepoint = 0;
-				isize offset = 0;
-				do {
-					isize width = utf8_decode(str.text+offset, str.len-offset, &codepoint);
-					if (val0_type) lb_addr_store(p, val0_addr, lb_const_value(m, val0_type, exact_value_i64(codepoint)));
-					if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(offset)));
-					lb_build_stmt(p, rs->body);
+		if (unroll_count_ev.kind == ExactValue_Invalid) {
+			GB_ASSERT(expr->tav.mode == Addressing_Constant);
 
-					offset += width;
-				} while (offset < str.len);
-			}
-			break;
-		case Type_Array:
-			if (t->Array.count > 0) {
-				lbValue val = lb_build_expr(p, expr);
-				lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
+			Type *t = base_type(expr->tav.type);
 
-				for (i64 i = 0; i < t->Array.count; i++) {
-					if (val0_type) {
-						// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
-						lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
-						lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
+			switch (t->kind) {
+			case Type_Basic:
+				GB_ASSERT(is_type_string(t));
+				{
+					ExactValue value = expr->tav.value;
+					GB_ASSERT(value.kind == ExactValue_String);
+					String str = value.value_string;
+					Rune codepoint = 0;
+					isize offset = 0;
+					do {
+						isize width = utf8_decode(str.text+offset, str.len-offset, &codepoint);
+						if (val0_type) lb_addr_store(p, val0_addr, lb_const_value(m, val0_type, exact_value_i64(codepoint)));
+						if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(offset)));
+						lb_build_stmt(p, rs->body);
+
+						offset += width;
+					} while (offset < str.len);
+				}
+				break;
+			case Type_Array:
+				if (t->Array.count > 0) {
+					lbValue val = lb_build_expr(p, expr);
+					lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
+
+					for (i64 i = 0; i < t->Array.count; i++) {
+						if (val0_type) {
+							// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
+							lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
+							lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
+						}
+						if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(i)));
+
+						lb_build_stmt(p, rs->body);
 					}
-				if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(i)));
 
-				lb_build_stmt(p, rs->body);
+				}
+				break;
+			case Type_EnumeratedArray:
+				if (t->EnumeratedArray.count > 0) {
+					lbValue val = lb_build_expr(p, expr);
+					lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
+
+					for (i64 i = 0; i < t->EnumeratedArray.count; i++) {
+						if (val0_type) {
+							// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
+							lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
+							lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
+						}
+						if (val1_type) {
+							ExactValue idx = exact_value_add(exact_value_i64(i), *t->EnumeratedArray.min_value);
+							lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, idx));
+						}
+
+						lb_build_stmt(p, rs->body);
+					}
+
+				}
+				break;
+			default:
+				GB_PANIC("Invalid '#unroll for' type");
+				break;
+			}
+		} else {
+
+			////////////////////////////////
+			//                            //
+			//      #unroll(N) logic      //
+			//                            //
+			////////////////////////////////
+
+
+			i64 unroll_count = exact_value_to_i64(unroll_count_ev);
+			gb_unused(unroll_count);
+
+			Type *t = base_type(expr->tav.type);
+
+			lbValue data_ptr = {};
+			lbValue count_ptr = {};
+
+			switch (t->kind) {
+			case Type_Slice:
+			case Type_DynamicArray: {
+				lbValue slice = lb_build_expr(p, expr);
+				if (is_type_pointer(slice.type)) {
+					count_ptr = lb_emit_struct_ep(p, slice, 1);
+					slice = lb_emit_load(p, slice);
+				} else {
+					count_ptr = lb_add_local_generated(p, t_int, false).addr;
+					lb_emit_store(p, count_ptr, lb_slice_len(p, slice));
+				}
+				data_ptr = lb_emit_struct_ev(p, slice, 0);
+				break;
+			}
+
+			case Type_Array: {
+				lbValue array = lb_build_expr(p, expr);
+				count_ptr = lb_add_local_generated(p, t_int, false).addr;
+				lb_emit_store(p, count_ptr, lb_const_int(p->module, t_int, t->Array.count));
+
+				if (!is_type_pointer(array.type)) {
+					array = lb_address_from_load_or_generate_local(p, array);
 				}
+				GB_ASSERT(is_type_pointer(array.type));
+				data_ptr = lb_emit_conv(p, array, alloc_type_pointer(t->Array.elem));
+				break;
 			}
-			break;
-		case Type_EnumeratedArray:
-			if (t->EnumeratedArray.count > 0) {
-				lbValue val = lb_build_expr(p, expr);
-				lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
-
-				for (i64 i = 0; i < t->EnumeratedArray.count; i++) {
-					if (val0_type) {
-						// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
-						lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
-						lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
-					}
-					if (val1_type) {
-						ExactValue idx = exact_value_add(exact_value_i64(i), *t->EnumeratedArray.min_value);
-						lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, idx));
-					}
+			default:
+				GB_PANIC("Invalid '#unroll for' type");
+				break;
+			}
+
+			data_ptr.type = alloc_type_multi_pointer_to_pointer(data_ptr.type);
+
+			lbBlock *loop_top = lb_create_block(p, "for.unroll.loop.top");
+
+			lbBlock *body_top = lb_create_block(p, "for.unroll.body.top");
+			lbBlock *body_bot = lb_create_block(p, "for.unroll.body.bot");
+
+			lbBlock *done = lb_create_block(p, "for.unroll.done");
+
+			lbBlock *loop_bot = unroll_count > 1 ? lb_create_block(p, "for.unroll.loop.bot") : done;
+
+			/*
+				i := 0
+				for ; i+N <= len(array); i += N {
+					body
+				}
+				for ; i < len(array); i += 1 {
+					body
+				}
+			*/
+
+			Entity *val_entity = val0 ? entity_of_node(val0) : nullptr;
+			Entity *idx_entity = val1 ? entity_of_node(val1) : nullptr;
+
+			lbAddr val_addr = lb_add_local(p, type_deref(data_ptr.type, true), val_entity);
+			lbAddr idx_addr = lb_add_local(p, t_int, idx_entity);
+			lb_addr_store(p, idx_addr, lb_const_nil(p->module, t_int));
+
+			lb_emit_jump(p, loop_top);
+			lb_start_block(p, loop_top);
+
+			lbValue idx_add_n = lb_addr_load(p, idx_addr);
+			idx_add_n = lb_emit_arith(p, Token_Add, idx_add_n, lb_const_int(p->module, t_int, unroll_count), t_int);
+
+			lbValue cond_top = lb_emit_comp(p, Token_LtEq, idx_add_n, lb_emit_load(p, count_ptr));
+			lb_emit_if(p, cond_top, body_top, loop_bot);
+
+			lb_start_block(p, body_top);
+			for (i64 top = 0; top < unroll_count; top++) {
+				lbValue idx = lb_addr_load(p, idx_addr);
+				lbValue val = lb_emit_load(p, lb_emit_ptr_offset(p, data_ptr, idx));
+				lb_addr_store(p, val_addr, val);
+
+				lb_build_stmt(p, rs->body);
+
+				lb_emit_increment(p, lb_addr_get_ptr(p, idx_addr));
+			}
+			lb_emit_jump(p, loop_top);
+
+			if (unroll_count > 1) {
+				lb_start_block(p, loop_bot);
+
+				lbValue cond_bot = lb_emit_comp(p, Token_Lt, lb_addr_load(p, idx_addr), lb_emit_load(p, count_ptr));
+				lb_emit_if(p, cond_bot, body_bot, done);
+
+				lb_start_block(p, body_bot);
+				{
+					lbValue idx = lb_addr_load(p, idx_addr);
+					lbValue val = lb_emit_load(p, lb_emit_ptr_offset(p, data_ptr, idx));
+					lb_addr_store(p, val_addr, val);
 
 					lb_build_stmt(p, rs->body);
-				}
+					lb_emit_increment(p, lb_addr_get_ptr(p, idx_addr));
+				}
+				lb_emit_jump(p, loop_bot);
 			}
-			break;
-		default:
-			GB_PANIC("Invalid '#unroll for' type");
-			break;
+
+			lb_close_scope(p, lbDeferExit_Default, nullptr, rs->body);
+			lb_emit_jump(p, done);
+			lb_start_block(p, done);
+
+			return;
 		}
 	}
 }
diff --git a/src/parser.cpp b/src/parser.cpp
index 03c5a5962..94f8fd42c 100644
--- a/src/parser.cpp
+++ b/src/parser.cpp
@@ -348,10 +348,11 @@ gb_internal Ast *clone_ast(Ast *node, AstFile *f) {
 		n->RangeStmt.body = clone_ast(n->RangeStmt.body, f);
 		break;
 	case Ast_UnrollRangeStmt:
-		n->UnrollRangeStmt.val0 = clone_ast(n->UnrollRangeStmt.val0, f);
-		n->UnrollRangeStmt.val1 = clone_ast(n->UnrollRangeStmt.val1, f);
-		n->UnrollRangeStmt.expr = clone_ast(n->UnrollRangeStmt.expr, f);
-		n->UnrollRangeStmt.body = clone_ast(n->UnrollRangeStmt.body, f);
+		n->UnrollRangeStmt.args = clone_ast_array(n->UnrollRangeStmt.args, f);
+		n->UnrollRangeStmt.val0 = clone_ast(n->UnrollRangeStmt.val0, f);
+		n->UnrollRangeStmt.val1 = clone_ast(n->UnrollRangeStmt.val1, f);
+		n->UnrollRangeStmt.expr = clone_ast(n->UnrollRangeStmt.expr, f);
+		n->UnrollRangeStmt.body = clone_ast(n->UnrollRangeStmt.body, f);
 		break;
 	case Ast_CaseClause:
 		n->CaseClause.list = clone_ast_array(n->CaseClause.list, f);
@@ -1037,15 +1038,16 @@ gb_internal Ast *ast_range_stmt(AstFile *f, Token token, Slice<Ast *> vals, Toke
 	return result;
 }
 
-gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
+gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Slice<Ast *> args, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
 	Ast *result = alloc_ast_node(f, Ast_UnrollRangeStmt);
 	result->UnrollRangeStmt.unroll_token = unroll_token;
+	result->UnrollRangeStmt.args         = args;
 	result->UnrollRangeStmt.for_token    = for_token;
-	result->UnrollRangeStmt.val0 = val0;
-	result->UnrollRangeStmt.val1 = val1;
-	result->UnrollRangeStmt.in_token = in_token;
-	result->UnrollRangeStmt.expr = expr;
-	result->UnrollRangeStmt.body = body;
+	result->UnrollRangeStmt.val0         = val0;
+	result->UnrollRangeStmt.val1         = val1;
+	result->UnrollRangeStmt.in_token     = in_token;
+	result->UnrollRangeStmt.expr         = expr;
+	result->UnrollRangeStmt.body         = body;
 	return result;
 }
 
@@ -5137,6 +5139,40 @@ gb_internal Ast *parse_attribute(AstFile *f, Token token, TokenKind open_kind, T
 
 
 gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
+	Array<Ast *> args = {};
+
+	if (allow_token(f, Token_OpenParen)) {
+		f->expr_level++;
+		if (f->curr_token.kind == Token_CloseParen) {
+			syntax_error(f->curr_token, "#unroll expected at least 1 argument, got 0");
+		} else {
+			args = array_make<Ast *>(ast_allocator(f));
+			while (f->curr_token.kind != Token_CloseParen &&
+			       f->curr_token.kind != Token_EOF) {
+				Ast *arg = nullptr;
+				arg = parse_value(f);
+
+				if (f->curr_token.kind == Token_Eq) {
+					Token eq = expect_token(f, Token_Eq);
+					if (arg != nullptr && arg->kind != Ast_Ident) {
+						syntax_error(arg, "Expected an identifier for 'key=value'");
+					}
+					Ast *value = parse_value(f);
+					arg = ast_field_value(f, arg, value, eq);
+				}
+
+				array_add(&args, arg);
+
+				if (!allow_field_separator(f)) {
+					break;
+				}
+			}
+		}
+		f->expr_level--;
+		Token close = expect_closing(f, Token_CloseParen, str_lit("#unroll"));
+		gb_unused(close);
+	}
+
 	Token for_token = expect_token(f, Token_for);
 	Ast *val0 = nullptr;
 	Ast *val1 = nullptr;
@@ -5180,7 +5216,7 @@ gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
 	if (bad_stmt) {
 		return ast_bad_stmt(f, unroll_token, f->curr_token);
 	}
-	return ast_unroll_range_stmt(f, unroll_token, for_token, val0, val1, in_token, expr, body);
+	return ast_unroll_range_stmt(f, unroll_token, slice_from_array(args), for_token, val0, val1, in_token, expr, body);
 }
 
 gb_internal Ast *parse_stmt(AstFile *f) {
diff --git a/src/parser.hpp b/src/parser.hpp
index bbf70d03e..d2dd22667 100644
--- a/src/parser.hpp
+++ b/src/parser.hpp
@@ -563,6 +563,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \
 	AST_KIND(UnrollRangeStmt, "#unroll range statement", struct { \
 		Scope *scope; \
 		Token unroll_token; \
+		Slice<Ast *> args; \
 		Token for_token; \
 		Ast *val0; \
 		Ast *val1; \
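
Not part of the diff above — a rough usage sketch of what these changes enable, inferred from parse_unrolled_for_loop, check_unroll_range_stmt, and the loop-shape comment in lb_build_unroll_range_stmt (the #unroll argument must be a compile-time integer constant between 1 and 1024; the variable names below are illustrative only):

	package main

	import "core:fmt"

	main :: proc() {
		// Pre-existing behaviour: with no argument, #unroll fully expands the
		// loop over a compile-time-known aggregate such as a fixed-size array.
		primes := [4]int{2, 3, 5, 7}
		#unroll for v, i in primes {
			fmt.println(i, v)
		}

		// New with this change: #unroll(N) accepts a slice or dynamic array
		// whose length is only known at run time; the body is emitted N times
		// per iteration of the generated loop, with a scalar remainder loop
		// handling the final len % N elements.
		data := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
		#unroll(4) for v in data {
			fmt.println(v)
		}
	}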