#unroll(N) for

This commit is contained in:
gingerBill
2025-01-10 12:14:43 +00:00
parent 4a2b13f1c2
commit 328d893cb5
7 changed files with 339 additions and 77 deletions

View File

@@ -432,10 +432,13 @@ Range_Stmt :: struct {
reverse: bool,
}
Inline_Range_Stmt :: struct {
Inline_Range_Stmt :: Unroll_Range_Stmt
Unroll_Range_Stmt :: struct {
using node: Stmt,
label: ^Expr,
inline_pos: tokenizer.Pos,
unroll_pos: tokenizer.Pos,
args: []^Expr,
for_pos: tokenizer.Pos,
val0: ^Expr,
val1: ^Expr,

View File

@@ -242,8 +242,9 @@ clone_node :: proc(node: ^Node) -> ^Node {
r.vals = clone(r.vals)
r.expr = clone(r.expr)
r.body = clone(r.body)
case ^Inline_Range_Stmt:
case ^Unroll_Range_Stmt:
r.label = clone(r.label)
r.args = clone(r.args)
r.val0 = clone(r.val0)
r.val1 = clone(r.val1)
r.expr = clone(r.expr)

View File

@@ -1262,11 +1262,49 @@ parse_foreign_decl :: proc(p: ^Parser) -> ^ast.Decl {
parse_unrolled_for_loop :: proc(p: ^Parser, inline_tok: tokenizer.Token) -> ^ast.Stmt {
for_tok := expect_token(p, .For)
val0, val1: ^ast.Expr
in_tok: tokenizer.Token
expr: ^ast.Expr
body: ^ast.Stmt
args: [dynamic]^ast.Expr
// Optional '#unroll(...)' argument list, parsed before the 'for' keyword.
// Each argument is either a plain value or a 'key=value' pair; the pair is
// wrapped in an ast.Field_Value node. Arguments are collected into 'args'.
if allow_token(p, .Open_Paren) {
	p.expr_level += 1
	if p.curr_tok.kind == .Close_Paren {
		// '#unroll()' with nothing inside the parentheses is rejected.
		error(p, p.curr_tok.pos, "#unroll expected at least 1 argument, got 0")
	} else {
		args = make([dynamic]^ast.Expr)
		for p.curr_tok.kind != .Close_Paren &&
		    p.curr_tok.kind != .EOF {
			arg := parse_value(p)

			if p.curr_tok.kind == .Eq {
				eq := expect_token(p, .Eq)

				// The key may be nil; in that case anchor the Field_Value at
				// the '=' token instead of dereferencing a nil pointer.
				key_pos := eq.pos
				if arg != nil {
					key_pos = arg.pos
					if _, ok := arg.derived.(^ast.Ident); !ok {
						error(p, arg.pos, "expected an identifier for 'key=value'")
					}
				}

				value := parse_value(p)
				fv := ast.new(ast.Field_Value, key_pos, value)
				fv.field = arg
				fv.sep   = eq.pos
				fv.value = value

				arg = fv
			}

			append(&args, arg)

			// Arguments are comma separated; stop at the first missing comma.
			allow_token(p, .Comma) or_break
		}
	}
	p.expr_level -= 1
	_ = expect_token_after(p, .Close_Paren, "#unroll")
}
for_tok := expect_token(p, .For)
bad_stmt := false
@@ -1309,7 +1347,8 @@ parse_unrolled_for_loop :: proc(p: ^Parser, inline_tok: tokenizer.Token) -> ^ast
}
range_stmt := ast.new(ast.Inline_Range_Stmt, inline_tok.pos, body)
range_stmt.inline_pos = inline_tok.pos
range_stmt.unroll_pos = inline_tok.pos
range_stmt.args = args[:]
range_stmt.for_pos = for_tok.pos
range_stmt.val0 = val0
range_stmt.val1 = val1

View File

@@ -894,15 +894,49 @@ gb_internal void error_var_decl_identifier(Ast *name) {
}
}
gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
gb_internal void check_unroll_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
ast_node(irs, UnrollRangeStmt, node);
check_open_scope(ctx, node);
defer (check_close_scope(ctx));
Type *val0 = nullptr;
Type *val1 = nullptr;
Entity *entities[2] = {};
isize entity_count = 0;
// Validate the optional '#unroll(N)' argument.
// 'unroll_count' stays -1 unless a constant integer >= 1 was supplied; the
// code further down in this function tests 'unroll_count > 0' to detect that
// an explicit count is present.
i64 unroll_count = -1;
if (irs->args.count > 0) {
// Only the first argument is used; a second one is reported as an error.
if (irs->args.count > 1) {
error(irs->args[1], "#unroll only supports a single argument for the unroll per loop amount");
}
Ast *arg = irs->args[0];
// 'key=value' arguments are rejected, but the value side is still checked
// below so that further diagnostics can be produced for it.
if (arg->kind == Ast_FieldValue) {
error(arg, "#unroll does not yet support named arguments");
arg = arg->FieldValue.value;
}
Operand x = {};
check_expr(ctx, &x, arg);
// The argument must be a compile-time constant of integer type.
if (x.mode != Addressing_Constant || !is_type_integer(x.type)) {
gbString s = expr_to_string(x.expr);
error(x.expr, "Expected a constant integer for #unroll, got '%s'", s);
gb_string_free(s);
} else {
ExactValue value = exact_value_to_integer(x.value);
i64 v = exact_value_to_i64(value);
if (v < 1) {
error(x.expr, "Expected a constant integer >= 1 for #unroll, got %lld", cast(long long)v);
} else {
// NOTE: the count is recorded before the upper-bound check, so a value
// above 1024 is reported as an error but 'unroll_count' still holds it.
unroll_count = v;
if (v > 1024) {
error(x.expr, "Too large of a value for #unroll, got %lld, expected <= 1024", cast(long long)v);
}
}
}
}
Ast *expr = unparen_expr(irs->expr);
ExactValue inline_for_depth = exact_value_i64(0);
@@ -946,18 +980,39 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
val0 = t_rune;
val1 = t_int;
inline_for_depth = exact_value_i64(operand.value.value_string.len);
if (unroll_count > 0) {
error(node, "#unroll(%lld) does not support strings", cast(long long)unroll_count);
}
}
break;
case Type_Array:
val0 = t->Array.elem;
val1 = t_int;
inline_for_depth = exact_value_i64(t->Array.count);
inline_for_depth = unroll_count > 0 ? exact_value_i64(unroll_count) : exact_value_i64(t->Array.count);
break;
case Type_EnumeratedArray:
val0 = t->EnumeratedArray.elem;
val1 = t->EnumeratedArray.index;
if (unroll_count > 0) {
error(node, "#unroll(%lld) does not support enumerated arrays", cast(long long)unroll_count);
}
inline_for_depth = exact_value_i64(t->EnumeratedArray.count);
break;
case Type_Slice:
if (unroll_count > 0) {
val0 = t->Slice.elem;
val1 = t_int;
inline_for_depth = exact_value_i64(unroll_count);
}
break;
case Type_DynamicArray:
if (unroll_count > 0) {
val0 = t->DynamicArray.elem;
val1 = t_int;
inline_for_depth = exact_value_i64(unroll_count);
}
break;
}
}
@@ -967,7 +1022,7 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
error(operand.expr, "Cannot iterate over '%s' of type '%s' in an '#unroll for' statement", s, t);
gb_string_free(t);
gb_string_free(s);
} else if (operand.mode != Addressing_Constant) {
} else if (operand.mode != Addressing_Constant && unroll_count <= 0) {
error(operand.expr, "An '#unroll for' expression must be known at compile time");
}
}
@@ -1050,8 +1105,6 @@ gb_internal void check_inline_range_stmt(CheckerContext *ctx, Ast *node, u32 mod
check_stmt(ctx, irs->body, mod_flags);
check_close_scope(ctx);
}
gb_internal void check_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags) {
@@ -2679,7 +2732,7 @@ gb_internal void check_stmt_internal(CheckerContext *ctx, Ast *node, u32 flags)
case_end;
case_ast_node(irs, UnrollRangeStmt, node);
check_inline_range_stmt(ctx, node, mod_flags);
check_unroll_range_stmt(ctx, node, mod_flags);
case_end;
case_ast_node(ss, SwitchStmt, node);

View File

@@ -256,7 +256,7 @@ gb_internal void lb_build_when_stmt(lbProcedure *p, AstWhenStmt *ws) {
gb_internal void lb_build_range_indexed(lbProcedure *p, lbValue expr, Type *val_type, lbValue count_ptr,
lbValue *val_, lbValue *idx_, lbBlock **loop_, lbBlock **done_,
bool is_reverse) {
bool is_reverse, i64 unroll_count=0) {
lbModule *m = p->module;
lbValue count = {};
@@ -1230,7 +1230,6 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
TypeAndValue tav = type_and_value_of_expr(expr);
if (is_ast_range(expr)) {
lbAddr val0_addr = {};
lbAddr val1_addr = {};
if (val0_type) val0_addr = lb_build_addr(p, val0);
@@ -1268,7 +1267,6 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
}
}
} else if (tav.mode == Addressing_Type) {
GB_ASSERT(is_type_enum(type_deref(tav.type)));
Type *et = type_deref(tav.type);
@@ -1293,72 +1291,203 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
if (val0_type) val0_addr = lb_build_addr(p, val0);
if (val1_type) val1_addr = lb_build_addr(p, val1);
GB_ASSERT(expr->tav.mode == Addressing_Constant);
Type *t = base_type(expr->tav.type);
ExactValue unroll_count_ev = {};
if (rs->args.count != 0) {
unroll_count_ev = rs->args[0]->tav.value;
}
switch (t->kind) {
case Type_Basic:
GB_ASSERT(is_type_string(t));
{
ExactValue value = expr->tav.value;
GB_ASSERT(value.kind == ExactValue_String);
String str = value.value_string;
Rune codepoint = 0;
isize offset = 0;
do {
isize width = utf8_decode(str.text+offset, str.len-offset, &codepoint);
if (val0_type) lb_addr_store(p, val0_addr, lb_const_value(m, val0_type, exact_value_i64(codepoint)));
if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(offset)));
lb_build_stmt(p, rs->body);
if (unroll_count_ev.kind == ExactValue_Invalid) {
GB_ASSERT(expr->tav.mode == Addressing_Constant);
offset += width;
} while (offset < str.len);
}
break;
case Type_Array:
if (t->Array.count > 0) {
lbValue val = lb_build_expr(p, expr);
lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
Type *t = base_type(expr->tav.type);
for (i64 i = 0; i < t->Array.count; i++) {
if (val0_type) {
// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
switch (t->kind) {
case Type_Basic:
GB_ASSERT(is_type_string(t));
{
ExactValue value = expr->tav.value;
GB_ASSERT(value.kind == ExactValue_String);
String str = value.value_string;
Rune codepoint = 0;
isize offset = 0;
do {
isize width = utf8_decode(str.text+offset, str.len-offset, &codepoint);
if (val0_type) lb_addr_store(p, val0_addr, lb_const_value(m, val0_type, exact_value_i64(codepoint)));
if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(offset)));
lb_build_stmt(p, rs->body);
offset += width;
} while (offset < str.len);
}
break;
case Type_Array:
if (t->Array.count > 0) {
lbValue val = lb_build_expr(p, expr);
lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
for (i64 i = 0; i < t->Array.count; i++) {
if (val0_type) {
// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
}
if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(i)));
lb_build_stmt(p, rs->body);
}
if (val1_type) lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, exact_value_i64(i)));
lb_build_stmt(p, rs->body);
}
break;
case Type_EnumeratedArray:
if (t->EnumeratedArray.count > 0) {
lbValue val = lb_build_expr(p, expr);
lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
for (i64 i = 0; i < t->EnumeratedArray.count; i++) {
if (val0_type) {
// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
}
if (val1_type) {
ExactValue idx = exact_value_add(exact_value_i64(i), *t->EnumeratedArray.min_value);
lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, idx));
}
lb_build_stmt(p, rs->body);
}
}
break;
default:
GB_PANIC("Invalid '#unroll for' type");
break;
}
} else {
////////////////////////////////
// //
// #unroll(N) logic //
// //
////////////////////////////////
i64 unroll_count = exact_value_to_i64(unroll_count_ev);
gb_unused(unroll_count);
Type *t = base_type(expr->tav.type);
lbValue data_ptr = {};
lbValue count_ptr = {};
switch (t->kind) {
case Type_Slice:
case Type_DynamicArray: {
lbValue slice = lb_build_expr(p, expr);
if (is_type_pointer(slice.type)) {
count_ptr = lb_emit_struct_ep(p, slice, 1);
slice = lb_emit_load(p, slice);
} else {
count_ptr = lb_add_local_generated(p, t_int, false).addr;
lb_emit_store(p, count_ptr, lb_slice_len(p, slice));
}
data_ptr = lb_emit_struct_ev(p, slice, 0);
break;
}
case Type_Array: {
lbValue array = lb_build_expr(p, expr);
count_ptr = lb_add_local_generated(p, t_int, false).addr;
lb_emit_store(p, count_ptr, lb_const_int(p->module, t_int, t->Array.count));
if (!is_type_pointer(array.type)) {
array = lb_address_from_load_or_generate_local(p, array);
}
GB_ASSERT(is_type_pointer(array.type));
data_ptr = lb_emit_conv(p, array, alloc_type_pointer(t->Array.elem));
break;
}
break;
case Type_EnumeratedArray:
if (t->EnumeratedArray.count > 0) {
lbValue val = lb_build_expr(p, expr);
lbValue val_addr = lb_address_from_load_or_generate_local(p, val);
for (i64 i = 0; i < t->EnumeratedArray.count; i++) {
if (val0_type) {
// NOTE(bill): Due to weird legacy issues in LLVM, this needs to be an i32
lbValue elem = lb_emit_array_epi(p, val_addr, cast(i32)i);
lb_addr_store(p, val0_addr, lb_emit_load(p, elem));
}
if (val1_type) {
ExactValue idx = exact_value_add(exact_value_i64(i), *t->EnumeratedArray.min_value);
lb_addr_store(p, val1_addr, lb_const_value(m, val1_type, idx));
}
default:
GB_PANIC("Invalid '#unroll for' type");
break;
}
data_ptr.type = alloc_type_multi_pointer_to_pointer(data_ptr.type);
lbBlock *loop_top = lb_create_block(p, "for.unroll.loop.top");
lbBlock *body_top = lb_create_block(p, "for.unroll.body.top");
lbBlock *body_bot = lb_create_block(p, "for.unroll.body.bot");
lbBlock *done = lb_create_block(p, "for.unroll.done");
lbBlock *loop_bot = unroll_count > 1 ? lb_create_block(p, "for.unroll.loop.bot") : done;
/*
i := 0
for ; i+N <= len(array); i += N {
body
}
for ; i < len(array); i += 1 {
body
}
*/
Entity *val_entity = val0 ? entity_of_node(val0) : nullptr;
Entity *idx_entity = val1 ? entity_of_node(val1) : nullptr;
lbAddr val_addr = lb_add_local(p, type_deref(data_ptr.type, true), val_entity);
lbAddr idx_addr = lb_add_local(p, t_int, idx_entity);
lb_addr_store(p, idx_addr, lb_const_nil(p->module, t_int));
lb_emit_jump(p, loop_top);
lb_start_block(p, loop_top);
lbValue idx_add_n = lb_addr_load(p, idx_addr);
idx_add_n = lb_emit_arith(p, Token_Add, idx_add_n, lb_const_int(p->module, t_int, unroll_count), t_int);
lbValue cond_top = lb_emit_comp(p, Token_LtEq, idx_add_n, lb_emit_load(p, count_ptr));
lb_emit_if(p, cond_top, body_top, loop_bot);
lb_start_block(p, body_top);
for (i64 top = 0; top < unroll_count; top++) {
lbValue idx = lb_addr_load(p, idx_addr);
lbValue val = lb_emit_load(p, lb_emit_ptr_offset(p, data_ptr, idx));
lb_addr_store(p, val_addr, val);
lb_build_stmt(p, rs->body);
lb_emit_increment(p, lb_addr_get_ptr(p, idx_addr));
}
lb_emit_jump(p, loop_top);
if (unroll_count > 1) {
lb_start_block(p, loop_bot);
lbValue cond_bot = lb_emit_comp(p, Token_Lt, lb_addr_load(p, idx_addr), lb_emit_load(p, count_ptr));
lb_emit_if(p, cond_bot, body_bot, done);
lb_start_block(p, body_bot);
{
lbValue idx = lb_addr_load(p, idx_addr);
lbValue val = lb_emit_load(p, lb_emit_ptr_offset(p, data_ptr, idx));
lb_addr_store(p, val_addr, val);
lb_build_stmt(p, rs->body);
}
lb_emit_increment(p, lb_addr_get_ptr(p, idx_addr));
}
lb_emit_jump(p, loop_bot);
}
break;
default:
GB_PANIC("Invalid '#unroll for' type");
break;
lb_close_scope(p, lbDeferExit_Default, nullptr, rs->body);
lb_emit_jump(p, done);
lb_start_block(p, done);
return;
}
}

View File

@@ -348,10 +348,11 @@ gb_internal Ast *clone_ast(Ast *node, AstFile *f) {
n->RangeStmt.body = clone_ast(n->RangeStmt.body, f);
break;
case Ast_UnrollRangeStmt:
n->UnrollRangeStmt.val0 = clone_ast(n->UnrollRangeStmt.val0, f);
n->UnrollRangeStmt.val1 = clone_ast(n->UnrollRangeStmt.val1, f);
n->UnrollRangeStmt.expr = clone_ast(n->UnrollRangeStmt.expr, f);
n->UnrollRangeStmt.body = clone_ast(n->UnrollRangeStmt.body, f);
n->UnrollRangeStmt.args = clone_ast_array(n->UnrollRangeStmt.args, f);
n->UnrollRangeStmt.val0 = clone_ast(n->UnrollRangeStmt.val0, f);
n->UnrollRangeStmt.val1 = clone_ast(n->UnrollRangeStmt.val1, f);
n->UnrollRangeStmt.expr = clone_ast(n->UnrollRangeStmt.expr, f);
n->UnrollRangeStmt.body = clone_ast(n->UnrollRangeStmt.body, f);
break;
case Ast_CaseClause:
n->CaseClause.list = clone_ast_array(n->CaseClause.list, f);
@@ -1037,15 +1038,16 @@ gb_internal Ast *ast_range_stmt(AstFile *f, Token token, Slice<Ast *> vals, Toke
return result;
}
gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Slice<Ast *> args, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
Ast *result = alloc_ast_node(f, Ast_UnrollRangeStmt);
result->UnrollRangeStmt.unroll_token = unroll_token;
result->UnrollRangeStmt.args = args;
result->UnrollRangeStmt.for_token = for_token;
result->UnrollRangeStmt.val0 = val0;
result->UnrollRangeStmt.val1 = val1;
result->UnrollRangeStmt.in_token = in_token;
result->UnrollRangeStmt.expr = expr;
result->UnrollRangeStmt.body = body;
result->UnrollRangeStmt.val0 = val0;
result->UnrollRangeStmt.val1 = val1;
result->UnrollRangeStmt.in_token = in_token;
result->UnrollRangeStmt.expr = expr;
result->UnrollRangeStmt.body = body;
return result;
}
@@ -5137,6 +5139,40 @@ gb_internal Ast *parse_attribute(AstFile *f, Token token, TokenKind open_kind, T
gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
// Parse the optional '#unroll(...)' argument list that may precede 'for'.
// 'args' stays zero-initialised when there are no parentheses, or when the
// parentheses are empty (which is also reported as a syntax error).
Array<Ast *> args = {};
if (allow_token(f, Token_OpenParen)) {
// expr_level is raised while inside the parentheses and restored afterwards.
f->expr_level++;
if (f->curr_token.kind == Token_CloseParen) {
syntax_error(f->curr_token, "#unroll expected at least 1 argument, got 0");
} else {
args = array_make<Ast *>(ast_allocator(f));
// Collect arguments until the closing parenthesis or end of file.
while (f->curr_token.kind != Token_CloseParen &&
f->curr_token.kind != Token_EOF) {
Ast *arg = nullptr;
arg = parse_value(f);
// 'key=value' form: the key must be an identifier; the pair is wrapped
// in a field-value node via ast_field_value.
if (f->curr_token.kind == Token_Eq) {
Token eq = expect_token(f, Token_Eq);
if (arg != nullptr && arg->kind != Ast_Ident) {
syntax_error(arg, "Expected an identifier for 'key=value'");
}
Ast *value = parse_value(f);
// NOTE(review): 'arg' is passed on even if it is nullptr — confirm that
// ast_field_value tolerates a null field.
arg = ast_field_value(f, arg, value, eq);
}
array_add(&args, arg);
// Stop at the first argument that is not followed by a field separator.
if (!allow_field_separator(f)) {
break;
}
}
}
f->expr_level--;
// The closing token is required but its value is not used afterwards.
Token close = expect_closing(f, Token_CloseParen, str_lit("#unroll"));
gb_unused(close);
}
Token for_token = expect_token(f, Token_for);
Ast *val0 = nullptr;
Ast *val1 = nullptr;
@@ -5180,7 +5216,7 @@ gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
if (bad_stmt) {
return ast_bad_stmt(f, unroll_token, f->curr_token);
}
return ast_unroll_range_stmt(f, unroll_token, for_token, val0, val1, in_token, expr, body);
return ast_unroll_range_stmt(f, unroll_token, slice_from_array(args), for_token, val0, val1, in_token, expr, body);
}
gb_internal Ast *parse_stmt(AstFile *f) {

View File

@@ -563,6 +563,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \
AST_KIND(UnrollRangeStmt, "#unroll range statement", struct { \
Scope *scope; \
Token unroll_token; \
Slice<Ast *> args; \
Token for_token; \
Ast *val0; \
Ast *val1; \