diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index 77b848315..cee00da23 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -1953,6 +1953,8 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { } } + case runtime.Type_Info_Matrix: + io.write_string(fi.writer, "[]") } } diff --git a/core/reflect/reflect.odin b/core/reflect/reflect.odin index f509ffe1b..7f64d0974 100644 --- a/core/reflect/reflect.odin +++ b/core/reflect/reflect.odin @@ -33,6 +33,7 @@ Type_Info_Bit_Set :: runtime.Type_Info_Bit_Set Type_Info_Simd_Vector :: runtime.Type_Info_Simd_Vector Type_Info_Relative_Pointer :: runtime.Type_Info_Relative_Pointer Type_Info_Relative_Slice :: runtime.Type_Info_Relative_Slice +Type_Info_Matrix :: runtime.Type_Info_Matrix Type_Info_Enum_Value :: runtime.Type_Info_Enum_Value @@ -66,6 +67,7 @@ Type_Kind :: enum { Simd_Vector, Relative_Pointer, Relative_Slice, + Matrix, } @@ -99,6 +101,7 @@ type_kind :: proc(T: typeid) -> Type_Kind { case Type_Info_Simd_Vector: return .Simd_Vector case Type_Info_Relative_Pointer: return .Relative_Pointer case Type_Info_Relative_Slice: return .Relative_Slice + case Type_Info_Matrix: return .Matrix } } @@ -1401,7 +1404,8 @@ equal :: proc(a, b: any, including_indirect_array_recursion := false, recursion_ Type_Info_Bit_Set, Type_Info_Enum, Type_Info_Simd_Vector, - Type_Info_Relative_Pointer: + Type_Info_Relative_Pointer, + Type_Info_Matrix: return mem.compare_byte_ptrs((^byte)(a.data), (^byte)(b.data), t.size) == 0 case Type_Info_String: diff --git a/core/reflect/types.odin b/core/reflect/types.odin index d0a96a088..cf79abb07 100644 --- a/core/reflect/types.odin +++ b/core/reflect/types.odin @@ -164,6 +164,12 @@ are_types_identical :: proc(a, b: ^Type_Info) -> bool { case Type_Info_Relative_Slice: y := b.variant.(Type_Info_Relative_Slice) or_return return x.base_integer == y.base_integer && x.slice == y.slice + + case Type_Info_Matrix: + y := b.variant.(Type_Info_Matrix) or_return + if x.row_count != y.row_count { return false } + if x.column_count != y.column_count { return false } + return are_types_identical(x.elem, y.elem) } return false @@ -584,6 +590,14 @@ write_type_writer :: proc(w: io.Writer, ti: ^Type_Info, n_written: ^int = nil) - write_type(w, info.base_integer, &n) or_return io.write_string(w, ") ", &n) or_return write_type(w, info.slice, &n) or_return + + case Type_Info_Matrix: + io.write_string(w, "[", &n) or_return + io.write_i64(w, i64(info.row_count), 10, &n) or_return + io.write_string(w, "; ", &n) or_return + io.write_i64(w, i64(info.column_count), 10, &n) or_return + io.write_string(w, "]", &n) or_return + write_type(w, info.elem, &n) or_return } return diff --git a/core/runtime/core.odin b/core/runtime/core.odin index 36a88a8b5..611b4002c 100644 --- a/core/runtime/core.odin +++ b/core/runtime/core.odin @@ -162,6 +162,13 @@ Type_Info_Relative_Slice :: struct { slice: ^Type_Info, base_integer: ^Type_Info, } +Type_Info_Matrix :: struct { + elem: ^Type_Info, + elem_size: int, + stride: int, // bytes + row_count: int, + column_count: int, +} Type_Info_Flag :: enum u8 { Comparable = 0, @@ -202,6 +209,7 @@ Type_Info :: struct { Type_Info_Simd_Vector, Type_Info_Relative_Pointer, Type_Info_Relative_Slice, + Type_Info_Matrix, }, } @@ -233,6 +241,7 @@ Typeid_Kind :: enum u8 { Simd_Vector, Relative_Pointer, Relative_Slice, + Matrix, } #assert(len(Typeid_Kind) < 32) diff --git a/core/runtime/print.odin b/core/runtime/print.odin index 3ccd4ef90..f32ac0831 100644 --- a/core/runtime/print.odin +++ b/core/runtime/print.odin @@ -370,5 +370,13 @@ print_type :: proc "contextless" (ti: ^Type_Info) { print_type(info.base_integer) print_string(") ") print_type(info.slice) + + case Type_Info_Matrix: + print_string("[") + print_u64(u64(info.row_count)) + print_string("; ") + print_u64(u64(info.column_count)) + print_string("]") + print_type(info.elem) } } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 513144f11..85f2eeb23 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -6207,6 +6207,17 @@ bool check_set_index_data(Operand *o, Type *t, bool indirection, i64 *max_count, } o->type = t->EnumeratedArray.elem; return true; + + case Type_Matrix: + *max_count = t->Matrix.column_count; + if (indirection) { + o->mode = Addressing_Variable; + } else if (o->mode != Addressing_Variable && + o->mode != Addressing_Constant) { + o->mode = Addressing_Value; + } + o->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count); + return true; case Type_Slice: o->type = t->Slice.elem; @@ -6505,6 +6516,11 @@ void check_promote_optional_ok(CheckerContext *c, Operand *x, Type **val_type_, } +void check_matrix_index_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) { + error(node, "TODO: matrix index expressions"); +} + + ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) { u32 prev_state_flags = c->state_flags; defer (c->state_flags = prev_state_flags); @@ -8202,6 +8218,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type // Okay } else if (is_type_relative_slice(t)) { // Okay + } else if (is_type_matrix(t)) { + // Okay } else { valid = false; } @@ -8266,10 +8284,14 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type } } } + + if (type_hint != nullptr && is_type_matrix(t)) { + // TODO(bill): allow matrix columns to be assignable to other types which are the same internally + // if a type hint exists + } + case_end; - - case_ast_node(se, SliceExpr, node); check_expr(c, o, se->expr); node->viral_state_flags |= se->expr->viral_state_flags; @@ -8442,7 +8464,12 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type } case_end; - + + case_ast_node(mie, MatrixIndexExpr, node); + check_matrix_index_expr(c, o, node, type_hint); + o->expr = node; + return Expr_Expr; + case_end; case_ast_node(ce, CallExpr, node); return check_call_expr(c, o, node, ce->proc, ce->args, ce->inlining, type_hint); @@ -8952,6 +8979,15 @@ gbString write_expr_to_string(gbString str, Ast *node, bool shorthand) { str = gb_string_append_rune(str, ']'); case_end; + case_ast_node(mie, MatrixIndexExpr, node); + str = write_expr_to_string(str, mie->expr, shorthand); + str = gb_string_append_rune(str, '['); + str = write_expr_to_string(str, mie->row_index, shorthand); + str = gb_string_appendc(str, "; "); + str = write_expr_to_string(str, mie->column_index, shorthand); + str = gb_string_append_rune(str, ']'); + case_end; + case_ast_node(e, Ellipsis, node); str = gb_string_appendc(str, ".."); str = write_expr_to_string(str, e->expr, shorthand); @@ -9023,6 +9059,16 @@ gbString write_expr_to_string(gbString str, Ast *node, bool shorthand) { str = gb_string_append_rune(str, ']'); str = write_expr_to_string(str, mt->value, shorthand); case_end; + + case_ast_node(mt, MatrixType, node); + str = gb_string_append_rune(str, '['); + str = write_expr_to_string(str, mt->row_count, shorthand); + str = gb_string_appendc(str, "; "); + str = write_expr_to_string(str, mt->column_count, shorthand); + str = gb_string_append_rune(str, ']'); + str = write_expr_to_string(str, mt->elem, shorthand); + case_end; + case_ast_node(f, Field, node); if (f->flags&FieldFlag_using) { diff --git a/src/check_type.cpp b/src/check_type.cpp index 0d5c0f977..e752f192d 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -2200,6 +2200,63 @@ void check_map_type(CheckerContext *ctx, Type *type, Ast *node) { // error(node, "'map' types are not yet implemented"); } +void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node) { + ast_node(mt, MatrixType, node); + + Operand row = {}; + Operand column = {}; + + i64 row_count = check_array_count(ctx, &row, mt->row_count); + i64 column_count = check_array_count(ctx, &column, mt->column_count); + + Type *elem = check_type_expr(ctx, mt->elem, nullptr); + + Type *generic_row = nullptr; + Type *generic_column = nullptr; + + if (row.mode == Addressing_Type && row.type->kind == Type_Generic) { + generic_row = row.type; + } + + if (column.mode == Addressing_Type && column.type->kind == Type_Generic) { + generic_column = column.type; + } + + if (row_count < MIN_MATRIX_ELEMENT_COUNT && generic_row == nullptr) { + gbString s = expr_to_string(row.expr); + error(row.expr, "Invalid matrix row count, expected %d+ rows, got %s", MIN_MATRIX_ELEMENT_COUNT, s); + gb_string_free(s); + } + + if (column_count < MIN_MATRIX_ELEMENT_COUNT && generic_column == nullptr) { + gbString s = expr_to_string(column.expr); + error(column.expr, "Invalid matrix column count, expected %d+ rows, got %s", MIN_MATRIX_ELEMENT_COUNT, s); + gb_string_free(s); + } + + if (row_count*column_count > MAX_MATRIX_ELEMENT_COUNT) { + i64 element_count = row_count*column_count; + error(column.expr, "Matrix types are limited to a maximum of %d elements, got %lld", MAX_MATRIX_ELEMENT_COUNT, cast(long long)element_count); + } + + if (is_type_integer(elem)) { + // okay + } else if (is_type_float(elem)) { + // okay + } else if (is_type_complex(elem)) { + // okay + } else { + gbString s = type_to_string(elem); + error(column.expr, "Matrix elements types are limited to integers, floats, and complex, got %s", s); + gb_string_free(s); + } + + *type = alloc_type_matrix(elem, row_count, column_count, generic_row, generic_column); + + return; +} + + Type *make_soa_struct_internal(CheckerContext *ctx, Ast *array_typ_expr, Ast *elem_expr, Type *elem, i64 count, Type *generic_type, StructSoaKind soa_kind) { Type *bt_elem = base_type(elem); @@ -2785,6 +2842,17 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t return true; } case_end; + + + case_ast_node(mt, MatrixType, e); + bool ips = ctx->in_polymorphic_specialization; + defer (ctx->in_polymorphic_specialization = ips); + ctx->in_polymorphic_specialization = false; + + check_matrix_type(ctx, type, e); + set_base_type(named_type, *type); + return true; + case_end; } *type = t_invalid; diff --git a/src/checker.cpp b/src/checker.cpp index d3c0080de..8711fdc0c 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -2458,6 +2458,7 @@ void init_core_type_info(Checker *c) { t_type_info_simd_vector = find_core_type(c, str_lit("Type_Info_Simd_Vector")); t_type_info_relative_pointer = find_core_type(c, str_lit("Type_Info_Relative_Pointer")); t_type_info_relative_slice = find_core_type(c, str_lit("Type_Info_Relative_Slice")); + t_type_info_matrix = find_core_type(c, str_lit("Type_Info_Matrix")); t_type_info_named_ptr = alloc_type_pointer(t_type_info_named); t_type_info_integer_ptr = alloc_type_pointer(t_type_info_integer); @@ -2485,6 +2486,7 @@ void init_core_type_info(Checker *c) { t_type_info_simd_vector_ptr = alloc_type_pointer(t_type_info_simd_vector); t_type_info_relative_pointer_ptr = alloc_type_pointer(t_type_info_relative_pointer); t_type_info_relative_slice_ptr = alloc_type_pointer(t_type_info_relative_slice); + t_type_info_matrix_ptr = alloc_type_pointer(t_type_info_matrix); } void init_mem_allocator(Checker *c) { diff --git a/src/llvm_backend_general.cpp b/src/llvm_backend_general.cpp index 094275429..ee8f220ef 100644 --- a/src/llvm_backend_general.cpp +++ b/src/llvm_backend_general.cpp @@ -1930,6 +1930,24 @@ LLVMTypeRef lb_type_internal(lbModule *m, Type *type) { fields[1] = base_integer; return LLVMStructTypeInContext(ctx, fields, field_count, false); } + + case Type_Matrix: + { + i64 size = type_size_of(type); + i64 elem_size = type_size_of(type->Matrix.elem); + GB_ASSERT(elem_size > 0); + i64 elem_count = size/elem_size; + GB_ASSERT(elem_count > 0); + + m->internal_type_level -= 1; + + LLVMTypeRef elem = lb_type(m, type->Matrix.elem); + LLVMTypeRef t = LLVMArrayType(elem, cast(unsigned)elem_count); + + m->internal_type_level += 1; + return t; + } + } GB_PANIC("Invalid type %s", type_to_string(type)); diff --git a/src/llvm_backend_type.cpp b/src/llvm_backend_type.cpp index e90bb6f16..82e20bf60 100644 --- a/src/llvm_backend_type.cpp +++ b/src/llvm_backend_type.cpp @@ -42,6 +42,7 @@ lbValue lb_typeid(lbModule *m, Type *type) { case Type_Pointer: kind = Typeid_Pointer; break; case Type_MultiPointer: kind = Typeid_Multi_Pointer; break; case Type_Array: kind = Typeid_Array; break; + case Type_Matrix: kind = Typeid_Matrix; break; case Type_EnumeratedArray: kind = Typeid_Enumerated_Array; break; case Type_Slice: kind = Typeid_Slice; break; case Type_DynamicArray: kind = Typeid_Dynamic_Array; break; @@ -868,7 +869,25 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da lb_emit_store(p, tag, res); } break; + case Type_Matrix: + { + tag = lb_const_ptr_cast(m, variant_ptr, t_type_info_matrix_ptr); + i64 ez = type_size_of(t->Matrix.elem); + LLVMValueRef vals[5] = { + lb_get_type_info_ptr(m, t->Matrix.elem).value, + lb_const_int(m, t_int, ez).value, + lb_const_int(m, t_int, matrix_type_stride(t)).value, + lb_const_int(m, t_int, t->Matrix.row_count).value, + lb_const_int(m, t_int, t->Matrix.column_count).value, + }; + + lbValue res = {}; + res.type = type_deref(tag.type); + res.value = llvm_const_named_struct(m, res.type, vals, gb_count_of(vals)); + lb_emit_store(p, tag, res); + } + break; } diff --git a/src/parser.cpp b/src/parser.cpp index 716986b5d..499bd337b 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -159,6 +159,11 @@ Ast *clone_ast(Ast *node) { n->IndexExpr.expr = clone_ast(n->IndexExpr.expr); n->IndexExpr.index = clone_ast(n->IndexExpr.index); break; + case Ast_MatrixIndexExpr: + n->MatrixIndexExpr.expr = clone_ast(n->MatrixIndexExpr.expr); + n->MatrixIndexExpr.row_index = clone_ast(n->MatrixIndexExpr.row_index); + n->MatrixIndexExpr.column_index = clone_ast(n->MatrixIndexExpr.column_index); + break; case Ast_DerefExpr: n->DerefExpr.expr = clone_ast(n->DerefExpr.expr); break; @@ -371,6 +376,11 @@ Ast *clone_ast(Ast *node) { n->MapType.key = clone_ast(n->MapType.key); n->MapType.value = clone_ast(n->MapType.value); break; + case Ast_MatrixType: + n->MatrixType.row_count = clone_ast(n->MatrixType.row_count); + n->MatrixType.column_count = clone_ast(n->MatrixType.column_count); + n->MatrixType.elem = clone_ast(n->MatrixType.elem); + break; } return n; @@ -574,6 +584,15 @@ Ast *ast_deref_expr(AstFile *f, Ast *expr, Token op) { } +Ast *ast_matrix_index_expr(AstFile *f, Ast *expr, Token open, Token close, Token interval, Ast *row, Ast *column) { + Ast *result = alloc_ast_node(f, Ast_MatrixIndexExpr); + result->MatrixIndexExpr.expr = expr; + result->MatrixIndexExpr.row_index = row; + result->MatrixIndexExpr.column_index = column; + result->MatrixIndexExpr.open = open; + result->MatrixIndexExpr.close = close; + return result; +} Ast *ast_ident(AstFile *f, Token token) { @@ -1066,6 +1085,14 @@ Ast *ast_map_type(AstFile *f, Token token, Ast *key, Ast *value) { return result; } +Ast *ast_matrix_type(AstFile *f, Token token, Ast *row_count, Ast *column_count, Ast *elem) { + Ast *result = alloc_ast_node(f, Ast_MatrixType); + result->MatrixType.token = token; + result->MatrixType.row_count = row_count; + result->MatrixType.column_count = column_count; + result->MatrixType.elem = elem; + return result; +} Ast *ast_foreign_block_decl(AstFile *f, Token token, Ast *foreign_library, Ast *body, CommentGroup *docs) { @@ -2214,6 +2241,19 @@ Ast *parse_operand(AstFile *f, bool lhs) { count_expr = parse_expr(f, false); f->expr_level--; } + if (allow_token(f, Token_Semicolon)) { + Ast *row_count = count_expr; + Ast *column_count = nullptr; + + f->expr_level++; + column_count = parse_expr(f, false); + f->expr_level--; + + expect_token(f, Token_CloseBracket); + + return ast_matrix_type(f, token, row_count, column_count, parse_type(f)); + } + expect_token(f, Token_CloseBracket); return ast_array_type(f, token, count_expr, parse_type(f)); } break; @@ -2676,6 +2716,11 @@ Ast *parse_atom_expr(AstFile *f, Ast *operand, bool lhs) { case Token_RangeHalf: syntax_error(f->curr_token, "Expected a colon, not a range"); /* fallthrough */ + case Token_Semicolon: // matrix index + if (f->curr_token.kind == Token_Semicolon && f->curr_token.string == "\n") { + syntax_error(f->curr_token, "Expected a ';', not a newline"); + } + /* fallthrough */ case Token_Colon: interval = advance_token(f); is_interval = true; @@ -2691,7 +2736,14 @@ Ast *parse_atom_expr(AstFile *f, Ast *operand, bool lhs) { close = expect_token(f, Token_CloseBracket); if (is_interval) { - operand = ast_slice_expr(f, operand, open, close, interval, indices[0], indices[1]); + if (interval.kind == Token_Semicolon) { + if (indices[0] == nullptr || indices[1] == nullptr) { + syntax_error(open, "Matrix index expressions require both row and column indices"); + } + operand = ast_matrix_index_expr(f, operand, open, close, interval, indices[0], indices[1]); + } else { + operand = ast_slice_expr(f, operand, open, close, interval, indices[0], indices[1]); + } } else { operand = ast_index_expr(f, operand, indices[0], open, close); } diff --git a/src/parser.hpp b/src/parser.hpp index f1779bdbc..b58047dfd 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -407,6 +407,7 @@ AST_KIND(_ExprBegin, "", bool) \ bool is_align_stack; \ InlineAsmDialectKind dialect; \ }) \ + AST_KIND(MatrixIndexExpr, "matrix index expression", struct { Ast *expr, *row_index, *column_index; Token open, close; }) \ AST_KIND(_ExprEnd, "", bool) \ AST_KIND(_StmtBegin, "", bool) \ AST_KIND(BadStmt, "bad statement", struct { Token begin, end; }) \ @@ -657,6 +658,12 @@ AST_KIND(_TypeBegin, "", bool) \ Ast *key; \ Ast *value; \ }) \ + AST_KIND(MatrixType, "matrix type", struct { \ + Token token; \ + Ast *row_count; \ + Ast *column_count; \ + Ast *elem; \ + }) \ AST_KIND(_TypeEnd, "", bool) enum AstKind { diff --git a/src/parser_pos.cpp b/src/parser_pos.cpp index 22d12621d..6ef0db215 100644 --- a/src/parser_pos.cpp +++ b/src/parser_pos.cpp @@ -35,6 +35,7 @@ Token ast_token(Ast *node) { } return node->ImplicitSelectorExpr.token; case Ast_IndexExpr: return node->IndexExpr.open; + case Ast_MatrixIndexExpr: return node->MatrixIndexExpr.open; case Ast_SliceExpr: return node->SliceExpr.open; case Ast_Ellipsis: return node->Ellipsis.token; case Ast_FieldValue: return node->FieldValue.eq; @@ -103,6 +104,7 @@ Token ast_token(Ast *node) { case Ast_EnumType: return node->EnumType.token; case Ast_BitSetType: return node->BitSetType.token; case Ast_MapType: return node->MapType.token; + case Ast_MatrixType: return node->MatrixType.token; } return empty_token; @@ -168,6 +170,7 @@ Token ast_end_token(Ast *node) { } return node->ImplicitSelectorExpr.token; case Ast_IndexExpr: return node->IndexExpr.close; + case Ast_MatrixIndexExpr: return node->MatrixIndexExpr.close; case Ast_SliceExpr: return node->SliceExpr.close; case Ast_Ellipsis: if (node->Ellipsis.expr) { @@ -345,6 +348,7 @@ Token ast_end_token(Ast *node) { } return ast_end_token(node->BitSetType.elem); case Ast_MapType: return ast_end_token(node->MapType.value); + case Ast_MatrixType: return ast_end_token(node->MatrixType.elem); } return empty_token; diff --git a/src/types.cpp b/src/types.cpp index a808b54fb..0313ade60 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -270,6 +270,13 @@ struct TypeProc { TYPE_KIND(RelativeSlice, struct { \ Type *slice_type; \ Type *base_integer; \ + }) \ + TYPE_KIND(Matrix, struct { \ + Type *elem; \ + i64 row_count; \ + i64 column_count; \ + Type *generic_row_count; \ + Type *generic_column_count; \ }) @@ -341,6 +348,7 @@ enum Typeid_Kind : u8 { Typeid_Simd_Vector, Typeid_Relative_Pointer, Typeid_Relative_Slice, + Typeid_Matrix, }; // IMPORTANT NOTE(bill): This must match the same as the in core.odin @@ -349,6 +357,13 @@ enum TypeInfoFlag : u32 { TypeInfoFlag_Simple_Compare = 1<<1, }; + +enum : int { + MIN_MATRIX_ELEMENT_COUNT = 1, + MAX_MATRIX_ELEMENT_COUNT = 16, +}; + + bool is_type_comparable(Type *t); bool is_type_simple_compare(Type *t); @@ -622,6 +637,7 @@ gb_global Type *t_type_info_bit_set = nullptr; gb_global Type *t_type_info_simd_vector = nullptr; gb_global Type *t_type_info_relative_pointer = nullptr; gb_global Type *t_type_info_relative_slice = nullptr; +gb_global Type *t_type_info_matrix = nullptr; gb_global Type *t_type_info_named_ptr = nullptr; gb_global Type *t_type_info_integer_ptr = nullptr; @@ -649,6 +665,7 @@ gb_global Type *t_type_info_bit_set_ptr = nullptr; gb_global Type *t_type_info_simd_vector_ptr = nullptr; gb_global Type *t_type_info_relative_pointer_ptr = nullptr; gb_global Type *t_type_info_relative_slice_ptr = nullptr; +gb_global Type *t_type_info_matrix_ptr = nullptr; gb_global Type *t_allocator = nullptr; gb_global Type *t_allocator_ptr = nullptr; @@ -804,6 +821,24 @@ Type *alloc_type_array(Type *elem, i64 count, Type *generic_count = nullptr) { return t; } +Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr) { + if (generic_row_count != nullptr || generic_column_count != nullptr) { + Type *t = alloc_type(Type_Matrix); + t->Matrix.elem = elem; + t->Matrix.row_count = row_count; + t->Matrix.column_count = column_count; + t->Matrix.generic_row_count = generic_row_count; + t->Matrix.generic_column_count = generic_column_count; + return t; + } + Type *t = alloc_type(Type_Matrix); + t->Matrix.elem = elem; + t->Matrix.row_count = row_count; + t->Matrix.column_count = column_count; + return t; +} + + Type *alloc_type_enumerated_array(Type *elem, Type *index, ExactValue const *min_value, ExactValue const *max_value, TokenKind op) { Type *t = alloc_type(Type_EnumeratedArray); t->EnumeratedArray.elem = elem; @@ -1208,6 +1243,20 @@ bool is_type_enumerated_array(Type *t) { t = base_type(t); return t->kind == Type_EnumeratedArray; } +bool is_type_matrix(Type *t) { + t = base_type(t); + return t->kind == Type_Matrix; +} + +i64 matrix_type_stride(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 align = type_align_of(t); + i64 elem_size = type_size_of(t->Matrix.elem); + i64 stride = align_formula(elem_size*t->Matrix.row_count, align); + return stride; +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray; @@ -1241,6 +1290,8 @@ Type *base_array_type(Type *t) { return bt->EnumeratedArray.elem; } else if (is_type_simd_vector(bt)) { return bt->SimdVector.elem; + } else if (is_type_matrix(bt)) { + return bt->Matrix.elem; } return t; } @@ -1315,11 +1366,16 @@ i64 get_array_type_count(Type *t) { Type *core_array_type(Type *t) { for (;;) { t = base_array_type(t); - if (t->kind != Type_Array && t->kind != Type_EnumeratedArray && t->kind != Type_SimdVector) { + switch (t->kind) { + case Type_Array: + case Type_EnumeratedArray: + case Type_SimdVector: + case Type_Matrix: break; + default: + return t; } } - return t; } @@ -1934,6 +1990,8 @@ bool is_type_comparable(Type *t) { return is_type_comparable(t->Array.elem); case Type_Proc: return true; + case Type_Matrix: + return is_type_comparable(t->Matrix.elem); case Type_BitSet: return true; @@ -1995,6 +2053,9 @@ bool is_type_simple_compare(Type *t) { case Type_Proc: case Type_BitSet: return true; + + case Type_Matrix: + return is_type_simple_compare(t->Matrix.elem); case Type_Struct: for_array(i, t->Struct.fields) { @@ -2107,6 +2168,14 @@ bool are_types_identical(Type *x, Type *y) { return (x->Array.count == y->Array.count) && are_types_identical(x->Array.elem, y->Array.elem); } break; + + case Type_Matrix: + if (y->kind == Type_Matrix) { + return x->Matrix.row_count == y->Matrix.row_count && + x->Matrix.column_count == y->Matrix.column_count && + are_types_identical(x->Matrix.elem, y->Matrix.elem); + } + break; case Type_DynamicArray: if (y->kind == Type_DynamicArray) { @@ -2982,7 +3051,7 @@ i64 type_align_of_internal(Type *t, TypePath *path) { if (path->failure) { return FAILURE_ALIGNMENT; } - i64 align = type_align_of_internal(t->Array.elem, path); + i64 align = type_align_of_internal(elem, path); if (pop) type_path_pop(path); return align; } @@ -2993,7 +3062,7 @@ i64 type_align_of_internal(Type *t, TypePath *path) { if (path->failure) { return FAILURE_ALIGNMENT; } - i64 align = type_align_of_internal(t->EnumeratedArray.elem, path); + i64 align = type_align_of_internal(elem, path); if (pop) type_path_pop(path); return align; } @@ -3102,6 +3171,22 @@ i64 type_align_of_internal(Type *t, TypePath *path) { // IMPORTANT TODO(bill): Figure out the alignment of vector types return gb_clamp(next_pow2(type_size_of_internal(t, path)), 1, build_context.max_align); } + + case Type_Matrix: { + Type *elem = t->Matrix.elem; + i64 row_count = t->Matrix.row_count; + // i64 column_count = t->Matrix.column_count; + bool pop = type_path_push(path, elem); + if (path->failure) { + return FAILURE_ALIGNMENT; + } + i64 elem_align = type_align_of_internal(elem, path); + if (pop) type_path_pop(path); + + i64 align = gb_clamp(elem_align * row_count, elem_align, build_context.max_align); + + return align; + } case Type_RelativePointer: return type_align_of_internal(t->RelativePointer.base_integer, path); @@ -3369,6 +3454,26 @@ i64 type_size_of_internal(Type *t, TypePath *path) { Type *elem = t->SimdVector.elem; return count * type_size_of_internal(elem, path); } + + case Type_Matrix: { + Type *elem = t->Matrix.elem; + i64 row_count = t->Matrix.row_count; + i64 column_count = t->Matrix.column_count; + bool pop = type_path_push(path, elem); + if (path->failure) { + return FAILURE_SIZE; + } + i64 elem_size = type_size_of_internal(elem, path); + if (pop) type_path_pop(path); + i64 align = type_align_of(t); + + /* + [3; 4]f32 -> [4]{x, y, z, _: f32} // extra padding for alignment reasons + */ + + i64 size = align_formula(elem_size * row_count, align) * column_count; + return size; + } case Type_RelativePointer: return type_size_of_internal(t->RelativePointer.base_integer, path);