From a750fc0ba63c9f1461bba4cc0446b1b4c2d2b3a9 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Tue, 19 Mar 2024 21:05:23 +0000 Subject: [PATCH] Add `#row_major matrix[R, C]T` As well as `#column_major matrix[R, C]T` as an alias for just `matrix[R, C]T`. This is because some libraries require a row_major internal layout but still want to be used with row or major oriented vectors. --- base/runtime/core.odin | 4 ++++ core/fmt/fmt.odin | 12 +++++++++-- core/reflect/types.odin | 4 ++++ src/check_expr.cpp | 12 ++++++++--- src/check_type.cpp | 2 +- src/llvm_backend_const.cpp | 4 ++-- src/llvm_backend_expr.cpp | 39 +++++++++++++++++++++++++++--------- src/llvm_backend_type.cpp | 3 ++- src/llvm_backend_utility.cpp | 22 +++++++++++++------- src/parser.cpp | 13 ++++++++++++ src/parser.hpp | 1 + src/types.cpp | 19 ++++++++++++++---- 12 files changed, 105 insertions(+), 30 deletions(-) diff --git a/base/runtime/core.odin b/base/runtime/core.odin index 8f27ca674..7ad3ef1d6 100644 --- a/base/runtime/core.odin +++ b/base/runtime/core.odin @@ -177,6 +177,10 @@ Type_Info_Matrix :: struct { row_count: int, column_count: int, // Total element count = column_count * elem_stride + layout: enum u8 { + Column_Major, // array of column vectors + Row_Major, // array of row vectors + }, } Type_Info_Soa_Pointer :: struct { elem: ^Type_Info, diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index 02803f882..6f9801bc8 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -2396,7 +2396,11 @@ fmt_matrix :: proc(fi: ^Info, v: any, verb: rune, info: runtime.Type_Info_Matrix for col in 0.. 0 { io.write_string(fi.writer, ", ", &fi.n) } - offset := (row + col*info.elem_stride)*info.elem_size + offset: int + switch info.layout { + case .Column_Major: offset = (row + col*info.elem_stride)*info.elem_size + case .Row_Major: offset = (col + row*info.elem_stride)*info.elem_size + } data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) @@ -2410,7 +2414,11 @@ fmt_matrix :: proc(fi: ^Info, v: any, verb: rune, info: runtime.Type_Info_Matrix for col in 0.. 0 { io.write_string(fi.writer, ", ", &fi.n) } - offset := (row + col*info.elem_stride)*info.elem_size + offset: int + switch info.layout { + case .Column_Major: offset = (row + col*info.elem_stride)*info.elem_size + case .Row_Major: offset = (col + row*info.elem_stride)*info.elem_size + } data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) diff --git a/core/reflect/types.odin b/core/reflect/types.odin index 2b96dd4fb..9cff46a00 100644 --- a/core/reflect/types.odin +++ b/core/reflect/types.odin @@ -173,6 +173,7 @@ are_types_identical :: proc(a, b: ^Type_Info) -> bool { y := b.variant.(Type_Info_Matrix) or_return if x.row_count != y.row_count { return false } if x.column_count != y.column_count { return false } + if x.layout != y.layout { return false } return are_types_identical(x.elem, y.elem) case Type_Info_Bit_Field: @@ -689,6 +690,9 @@ write_type_writer :: proc(w: io.Writer, ti: ^Type_Info, n_written: ^int = nil) - write_type(w, info.pointer, &n) or_return case Type_Info_Matrix: + if info.layout == .Row_Major { + io.write_string(w, "#row_major ", &n) or_return + } io.write_string(w, "matrix[", &n) or_return io.write_i64(w, i64(info.row_count), 10, &n) or_return io.write_string(w, ", ", &n) or_return diff --git a/src/check_expr.cpp b/src/check_expr.cpp index f359d5a54..67c7f1a04 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -3431,6 +3431,11 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand if (xt->Matrix.column_count != yt->Matrix.row_count) { goto matrix_error; } + + if (xt->Matrix.is_row_major != yt->Matrix.is_row_major) { + goto matrix_error; + } + x->mode = Addressing_Value; if (are_types_identical(xt, yt)) { if (!is_type_named(x->type) && is_type_named(y->type)) { @@ -3438,7 +3443,8 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand x->type = y->type; } } else { - x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count); + bool is_row_major = xt->Matrix.is_row_major && yt->Matrix.is_row_major; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count, nullptr, nullptr, is_row_major); } goto matrix_success; } else if (yt->kind == Type_Array) { @@ -3452,7 +3458,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand // Treat arrays as column vectors x->mode = Addressing_Value; - if (type_hint == nullptr && xt->Matrix.row_count == yt->Array.count) { + if (xt->Matrix.row_count == yt->Array.count) { x->type = y->type; } else { x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); @@ -3483,7 +3489,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand // Treat arrays as row vectors x->mode = Addressing_Value; - if (type_hint == nullptr && yt->Matrix.column_count == xt->Array.count) { + if (yt->Matrix.column_count == xt->Array.count) { x->type = x->type; } else { x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count); diff --git a/src/check_type.cpp b/src/check_type.cpp index d5cf187a4..3fe633892 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -2658,7 +2658,7 @@ gb_internal void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node) } type_assign:; - *type = alloc_type_matrix(elem, row_count, column_count, generic_row, generic_column); + *type = alloc_type_matrix(elem, row_count, column_count, generic_row, generic_column, mt->is_row_major); return; } diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 2291f24ac..bbb0b8387 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -1302,11 +1302,11 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bo GB_ASSERT_MSG(elem_count == max_count, "%td != %td", elem_count, max_count); LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); - for_array(i, cl->elems) { TypeAndValue tav = cl->elems[i]->tav; GB_ASSERT(tav.mode != Addressing_Invalid); - i64 offset = matrix_row_major_index_to_offset(type, i); + i64 offset = 0; + offset = matrix_row_major_index_to_offset(type, i); values[offset] = lb_const_value(m, elem_type, tav.value, allow_local).value; } for (isize i = 0; i < total_count; i++) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 98618798b..6eb8fdcc6 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -684,12 +684,6 @@ gb_internal lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type Type *mt = base_type(m.type); GB_ASSERT(mt->kind == Type_Matrix); - // TODO(bill): Determine why this fails on Windows sometimes - if (false && lb_is_matrix_simdable(mt)) { - LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m); - return lb_matrix_cast_vector_to_type(p, vector, type); - } - lbAddr res = lb_add_local_generated(p, type, true); i64 row_count = mt->Matrix.row_count; @@ -763,6 +757,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, GB_ASSERT(is_type_matrix(yt)); GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + GB_ASSERT(xt->Matrix.is_row_major == yt->Matrix.is_row_major); Type *elem = xt->Matrix.elem; @@ -770,7 +765,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, unsigned inner = cast(unsigned)xt->Matrix.column_count; unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; - if (lb_is_matrix_simdable(xt)) { + if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) { unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); @@ -812,7 +807,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, return lb_addr_load(p, res); } - { + if (!xt->Matrix.is_row_major) { lbAddr res = lb_add_local_generated(p, type, true); auto inners = slice_make(permanent_allocator(), inner); @@ -835,6 +830,30 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, } } + return lb_addr_load(p, res); + } else { + lbAddr res = lb_add_local_generated(p, type, true); + + auto inners = slice_make(permanent_allocator(), inner); + + for (unsigned i = 0; i < outer_rows; i++) { + for (unsigned j = 0; j < outer_columns; j++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + for (unsigned k = 0; k < inner; k++) { + inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k); + inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j); + } + + lbValue sum = lb_const_nil(p->module, elem); + for (unsigned k = 0; k < inner; k++) { + lbValue a = inners[k][0]; + lbValue b = inners[k][1]; + sum = lb_emit_mul_add(p, a, b, sum, elem); + } + lb_emit_store(p, dst, sum); + } + } + return lb_addr_load(p, res); } } @@ -855,7 +874,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal Type *elem = mt->Matrix.elem; - if (lb_is_matrix_simdable(mt)) { + if (!mt->Matrix.is_row_major && lb_is_matrix_simdable(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; @@ -924,7 +943,7 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal Type *elem = mt->Matrix.elem; - if (lb_is_matrix_simdable(mt)) { + if (!mt->Matrix.is_row_major && lb_is_matrix_simdable(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; diff --git a/src/llvm_backend_type.cpp b/src/llvm_backend_type.cpp index aec1fb201..de83f5058 100644 --- a/src/llvm_backend_type.cpp +++ b/src/llvm_backend_type.cpp @@ -979,12 +979,13 @@ gb_internal void lb_setup_type_info_data_giant_array(lbModule *m, i64 global_typ tag_type = t_type_info_matrix; i64 ez = type_size_of(t->Matrix.elem); - LLVMValueRef vals[5] = { + LLVMValueRef vals[6] = { get_type_info_ptr(m, t->Matrix.elem), lb_const_int(m, t_int, ez).value, lb_const_int(m, t_int, matrix_type_stride_in_elems(t)).value, lb_const_int(m, t_int, t->Matrix.row_count).value, lb_const_int(m, t_int, t->Matrix.column_count).value, + lb_const_int(m, t_u8, cast(u8)t->Matrix.is_row_major).value, }; variant_value = llvm_const_named_struct(m, tag_type, vals, gb_count_of(vals)); diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index f18aa5521..fb61c41c3 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1464,14 +1464,16 @@ gb_internal lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isi Type *t = s.type; GB_ASSERT(is_type_pointer(t)); Type *mt = base_type(type_deref(t)); - if (column == 0) { - GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt)); - return lb_emit_epi(p, s, row); - } else if (row == 0 && is_type_array_like(mt)) { - return lb_emit_epi(p, s, column); + + if (!mt->Matrix.is_row_major) { + if (column == 0) { + GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt)); + return lb_emit_epi(p, s, row); + } else if (row == 0 && is_type_array_like(mt)) { + return lb_emit_epi(p, s, column); + } } - GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); isize offset = matrix_indices_to_offset(mt, row, column); @@ -1491,7 +1493,13 @@ gb_internal lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lb row = lb_emit_conv(p, row, t_int); column = lb_emit_conv(p, column, t_int); - LLVMValueRef index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), ""); + LLVMValueRef index = nullptr; + + if (mt->Matrix.is_row_major) { + index = LLVMBuildAdd(p->builder, column.value, LLVMBuildMul(p->builder, row.value, stride_elems, ""), ""); + } else { + index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), ""); + } LLVMValueRef indices[2] = { LLVMConstInt(lb_type(p->module, t_int), 0, false), diff --git a/src/parser.cpp b/src/parser.cpp index 1aa40ccbf..eb9e73342 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -2329,6 +2329,19 @@ gb_internal Ast *parse_operand(AstFile *f, bool lhs) { break; } return original_type; + } else if (name.string == "row_major" || + name.string == "column_major") { + Ast *original_type = parse_type(f); + Ast *type = unparen_expr(original_type); + switch (type->kind) { + case Ast_MatrixType: + type->MatrixType.is_row_major = (name.string == "row_major"); + break; + default: + syntax_error(type, "Expected a matrix type after #%.*s, got %.*s", LIT(name.string), LIT(ast_strings[type->kind])); + break; + } + return original_type; } else if (name.string == "partial") { Ast *tag = ast_basic_directive(f, token, name); Ast *original_expr = parse_expr(f, lhs); diff --git a/src/parser.hpp b/src/parser.hpp index f5997c4bd..ff3c5eb34 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -772,6 +772,7 @@ AST_KIND(_TypeBegin, "", bool) \ Ast *row_count; \ Ast *column_count; \ Ast *elem; \ + bool is_row_major; \ }) \ AST_KIND(_TypeEnd, "", bool) diff --git a/src/types.cpp b/src/types.cpp index 5a3ad5d6b..ab5e4de03 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -281,6 +281,7 @@ struct TypeProc { Type *generic_row_count; \ Type *generic_column_count; \ i64 stride_in_bytes; \ + bool is_row_major; \ }) \ TYPE_KIND(BitField, struct { \ Scope * scope; \ @@ -1002,7 +1003,7 @@ gb_internal Type *alloc_type_array(Type *elem, i64 count, Type *generic_count = return t; } -gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr) { +gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr, bool is_row_major = false) { if (generic_row_count != nullptr || generic_column_count != nullptr) { Type *t = alloc_type(Type_Matrix); t->Matrix.elem = elem; @@ -1010,12 +1011,14 @@ gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, t->Matrix.column_count = column_count; t->Matrix.generic_row_count = generic_row_count; t->Matrix.generic_column_count = generic_column_count; + t->Matrix.is_row_major = is_row_major; return t; } Type *t = alloc_type(Type_Matrix); t->Matrix.elem = elem; t->Matrix.row_count = row_count; t->Matrix.column_count = column_count; + t->Matrix.is_row_major = is_row_major; return t; } @@ -1512,14 +1515,18 @@ gb_internal i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_inde GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count); GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count); i64 stride_elems = matrix_type_stride_in_elems(t); - // NOTE(bill): Column-major layout internally - return row_index + stride_elems*column_index; + if (t->Matrix.is_row_major) { + return column_index + stride_elems*row_index; + } else { + // NOTE(bill): Column-major layout internally + return row_index + stride_elems*column_index; + } } gb_internal i64 matrix_row_major_index_to_offset(Type *t, i64 index) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); - + i64 row_index = index/t->Matrix.column_count; i64 column_index = index%t->Matrix.column_count; return matrix_indices_to_offset(t, row_index, column_index); @@ -2690,6 +2697,7 @@ gb_internal bool are_types_identical_internal(Type *x, Type *y, bool check_tuple case Type_Matrix: return x->Matrix.row_count == y->Matrix.row_count && x->Matrix.column_count == y->Matrix.column_count && + x->Matrix.is_row_major == y->Matrix.is_row_major && are_types_identical(x->Matrix.elem, y->Matrix.elem); case Type_DynamicArray: @@ -4735,6 +4743,9 @@ gb_internal gbString write_type_to_string(gbString str, Type *type, bool shortha break; case Type_Matrix: + if (type->Matrix.is_row_major) { + str = gb_string_appendc(str, "#row_major "); + } str = gb_string_appendc(str, gb_bprintf("matrix[%d, %d]", cast(int)type->Matrix.row_count, cast(int)type->Matrix.column_count)); str = write_type_to_string(str, type->Matrix.elem); break;