From ba331024af2f5074125442e91dda6c8e63324c8f Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 18 Oct 2021 18:16:52 +0100 Subject: [PATCH] Very basic matrix support in backend --- core/fmt/fmt.odin | 35 +++++++- src/check_expr.cpp | 153 ++++++++++++++++++++++++++++++++++- src/checker.cpp | 8 ++ src/llvm_backend.hpp | 4 + src/llvm_backend_const.cpp | 28 +++++++ src/llvm_backend_expr.cpp | 78 ++++++++++++++++++ src/llvm_backend_utility.cpp | 35 ++++++++ src/types.cpp | 31 ++++++- 8 files changed, 364 insertions(+), 8 deletions(-) diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index cee00da23..804a29cab 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -1954,7 +1954,40 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { } case runtime.Type_Info_Matrix: - io.write_string(fi.writer, "[]") + reflect.write_type(fi.writer, type_info_of(v.id)) + io.write_byte(fi.writer, '{') + defer io.write_byte(fi.writer, '}') + + fi.indent += 1; defer fi.indent -= 1 + + if fi.hash { + io.write_byte(fi.writer, '\n') + // TODO(bill): Should this render it like in written form? e.g. tranposed + for col in 0.. 0 { io.write_string(fi.writer, ", ") } + + offset := row*info.elem_size + col*info.stride + + data := uintptr(v.data) + uintptr(offset) + fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) + } + io.write_string(fi.writer, ";\n") + } + } else { + for col in 0.. 0 { io.write_string(fi.writer, "; ") } + for row in 0.. 0 { io.write_string(fi.writer, ", ") } + + offset := row*info.elem_size + col*info.stride + + data := uintptr(v.data) + uintptr(offset) + fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) + } + } + } } } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 85f2eeb23..9c12802d7 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -1400,8 +1400,9 @@ bool check_unary_op(CheckerContext *c, Operand *o, Token op) { } bool check_binary_op(CheckerContext *c, Operand *o, Token op) { + Type *main_type = o->type; // TODO(bill): Handle errors correctly - Type *type = base_type(core_array_type(o->type)); + Type *type = base_type(core_array_type(main_type)); Type *ct = core_type(type); switch (op.kind) { case Token_Sub: @@ -1414,10 +1415,15 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { } break; - case Token_Mul: case Token_Quo: - case Token_MulEq: case Token_QuoEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } + /*fallthrough*/ + case Token_Mul: + case Token_MulEq: case Token_AddEq: if (is_type_bit_set(type)) { return true; @@ -1458,6 +1464,10 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { case Token_ModMod: case Token_ModEq: case Token_ModModEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } if (!is_type_integer(type)) { error(op, "Operator '%.*s' is only allowed with integers", LIT(op.string)); return false; @@ -2671,6 +2681,114 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type } +void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) { + if (!check_binary_op(c, x, op)) { + x->mode = Addressing_Invalid; + return; + } + + if (is_type_matrix(x->type)) { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(xt->kind == Type_Matrix); + if (op.kind == Token_Mul) { + if (yt->kind == Type_Matrix) { + if (!are_types_identical(xt->Matrix.elem, yt->Matrix.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Matrix.row_count) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count); + goto matrix_success; + } else if (yt->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, yt->Array.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Array.count) { + goto matrix_error; + } + + // Treat arrays as column vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } else { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(!is_type_matrix(xt)); + + if (op.kind == Token_Mul) { + // NOTE(bill): no need to handle the matrix case here since it should be handled above + if (xt->kind == Type_Array) { + if (!are_types_identical(yt->Matrix.elem, xt->Array.elem)) { + goto matrix_error; + } + + if (xt->Array.count != yt->Matrix.row_count) { + goto matrix_error; + } + + // Treat arrays as row vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } + +matrix_success: + if (type_hint != nullptr) { + Type *th = base_type(type_hint); + if (are_types_identical(th, x->type)) { + x->type = type_hint; + } else if (x->type->kind == Type_Matrix && th->kind == Type_Array) { + Type *xt = x->type; + if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { + // ignore + } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { + x->type = type_hint; + } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { + x->type = type_hint; + } + } + } + return; + + +matrix_error: + gbString xt = type_to_string(x->type); + gbString yt = type_to_string(y->type); + gbString expr_str = expr_to_string(x->expr); + error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt); + gb_string_free(expr_str); + gb_string_free(yt); + gb_string_free(xt); + x->type = t_invalid; + x->mode = Addressing_Invalid; + return; + +} + + void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) { GB_ASSERT(node->kind == Ast_BinaryExpr); Operand y_ = {}, *y = &y_; @@ -2874,6 +2992,12 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint x->type = y->type; return; } + if (is_type_matrix(x->type) || is_type_matrix(y->type)) { + check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint); + return; + } + + if (!are_types_identical(x->type, y->type)) { if (x->type != t_invalid && y->type != t_invalid) { @@ -3258,6 +3382,29 @@ void convert_to_typed(CheckerContext *c, Operand *operand, Type *target_type) { break; } + + case Type_Matrix: { + Type *elem = base_array_type(t); + if (check_is_assignable_to(c, operand, elem)) { + if (t->Matrix.row_count != t->Matrix.column_count) { + operand->mode = Addressing_Invalid; + begin_error_block(); + defer (end_error_block()); + + convert_untyped_error(c, operand, target_type); + error_line("\tNote: Only a square matrix types can be initialized with a scalar value\n"); + return; + } else { + operand->mode = Addressing_Value; + } + } else { + operand->mode = Addressing_Invalid; + convert_untyped_error(c, operand, target_type); + return; + } + break; + } + case Type_Union: if (!is_operand_nil(*operand) && !is_operand_undef(*operand)) { diff --git a/src/checker.cpp b/src/checker.cpp index 8711fdc0c..c0e6d47c0 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -1659,6 +1659,10 @@ void add_type_info_type_internal(CheckerContext *c, Type *t) { add_type_info_type_internal(c, bt->RelativeSlice.slice_type); add_type_info_type_internal(c, bt->RelativeSlice.base_integer); break; + + case Type_Matrix: + add_type_info_type_internal(c, bt->Matrix.elem); + break; default: GB_PANIC("Unhandled type: %*.s %d", LIT(type_strings[bt->kind]), bt->kind); @@ -1870,6 +1874,10 @@ void add_min_dep_type_info(Checker *c, Type *t) { add_min_dep_type_info(c, bt->RelativeSlice.slice_type); add_min_dep_type_info(c, bt->RelativeSlice.base_integer); break; + + case Type_Matrix: + add_min_dep_type_info(c, bt->Matrix.elem); + break; default: GB_PANIC("Unhandled type: %*.s", LIT(type_strings[bt->kind])); diff --git a/src/llvm_backend.hpp b/src/llvm_backend.hpp index ffb81f0e4..73ddad797 100644 --- a/src/llvm_backend.hpp +++ b/src/llvm_backend.hpp @@ -333,6 +333,10 @@ lbValue lb_emit_array_ep(lbProcedure *p, lbValue s, lbValue index); lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel); lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel); +lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column); +lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column); + + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type); lbValue lb_emit_byte_swap(lbProcedure *p, lbValue value, Type *end_type); void lb_emit_defer_stmts(lbProcedure *p, lbDeferExitKind kind, lbBlock *block); diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 68050e0ce..4cfcecdc3 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -512,6 +512,34 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc res.value = llvm_const_array(lb_type(m, elem), elems, cast(unsigned)count); return res; + } else if (is_type_matrix(type) && + value.kind != ExactValue_Invalid && + value.kind != ExactValue_Compound) { + i64 row = type->Matrix.row_count; + i64 column = type->Matrix.column_count; + GB_ASSERT(row == column); + + Type *elem = type->Matrix.elem; + + lbValue single_elem = lb_const_value(m, elem, value, allow_local); + single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); + + i64 stride_bytes = matrix_type_stride(type); + i64 stride_elems = stride_bytes/type_size_of(elem); + + i64 total_elem_count = matrix_type_total_elems(type); + LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); + for (i64 i = 0; i < row; i++) { + elems[i*stride_elems + i] = single_elem.value; + } + for (i64 i = 0; i < total_elem_count; i++) { + if (elems[i] == nullptr) { + elems[i] = LLVMConstNull(lb_type(m, elem)); + } + } + + res.value = LLVMConstArray(lb_type(m, elem), elems, cast(unsigned)total_elem_count); + return res; } switch (value.kind) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 3056952f6..6b7d90ec0 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -477,10 +477,72 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { + GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + if (op == Token_Mul) { + if (xt->kind == Type_Matrix) { + if (yt->kind == Type_Matrix) { + GB_ASSERT(is_type_matrix(type)); + GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); + GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + + Type *elem = xt->Matrix.elem; + + lbAddr res = lb_add_local_generated(p, type, true); + for (i64 i = 0; i < xt->Matrix.row_count; i++) { + for (i64 j = 0; j < yt->Matrix.column_count; j++) { + for (i64 k = 0; k < xt->Matrix.column_count; k++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + + lbValue a = lb_emit_matrix_ev(p, lhs, i, k); + lbValue b = lb_emit_matrix_ev(p, rhs, k, j); + lbValue c = lb_emit_arith(p, op, a, b, elem); + lbValue d = lb_emit_load(p, dst); + lbValue e = lb_emit_arith(p, Token_Add, d, c, elem); + lb_emit_store(p, dst, e); + + } + } + } + + return lb_addr_load(p, res); + } + } + + } else { + GB_ASSERT(are_types_identical(xt, yt)); + GB_ASSERT(xt->kind == Type_Matrix); + // element-wise arithmetic + // pretend it is an array + lbValue array_lhs = lhs; + lbValue array_rhs = rhs; + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); + GB_ASSERT(type_size_of(array_type) == type_size_of(type)); + + array_lhs.type = array_type; + array_rhs.type = array_type; + + lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type); + array.type = type; + return array; + } + + GB_PANIC("TODO: lb_emit_arith_matrix"); + + return {}; +} + + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) { return lb_emit_arith_array(p, op, lhs, rhs, type); + } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) { + return lb_emit_arith_matrix(p, op, lhs, rhs, type); } else if (is_type_complex(type)) { lhs = lb_emit_conv(p, lhs, type); rhs = lb_emit_conv(p, rhs, type); @@ -1417,6 +1479,22 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { } return lb_addr_load(p, v); } + + if (is_type_matrix(dst) && !is_type_matrix(src)) { + GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count); + + Type *elem = base_array_type(dst); + lbValue e = lb_emit_conv(p, value, elem); + lbAddr v = lb_add_local_generated(p, t, false); + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + isize j = cast(isize)i; + lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j); + lb_emit_store(p, ptr, e); + } + + + return lb_addr_load(p, v); + } if (is_type_any(dst)) { if (is_type_untyped_nil(src)) { diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index 0531c62bb..1b41be2a3 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1221,6 +1221,41 @@ lbValue lb_emit_ptr_offset(lbProcedure *p, lbValue ptr, lbValue index) { return res; } +lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) { + Type *t = s.type; + GB_ASSERT(is_type_pointer(t)); + Type *st = base_type(type_deref(t)); + GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st)); + + Type *ptr = base_array_type(st); + + isize index = row*column; + GB_ASSERT(0 <= index); + + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)index, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + res.type = alloc_type_pointer(ptr); + return res; +} + +lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) { + Type *st = base_type(s.type); + GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st)); + + lbValue value = lb_address_from_load_or_generate_local(p, s); + lbValue ptr = lb_emit_matrix_epi(p, value, row, column); + return lb_emit_load(p, ptr); +} + void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) { Type *t = lb_addr_type(slice); diff --git a/src/types.cpp b/src/types.cpp index 0313ade60..fd9b20c91 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1257,6 +1257,22 @@ i64 matrix_type_stride(Type *t) { return stride; } +i64 matrix_type_stride_in_elems(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 stride = matrix_type_stride(t); + return stride/gb_max(1, type_size_of(t->Matrix.elem)); +} + + +i64 matrix_type_total_elems(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 size = type_size_of(t); + i64 elem_size = type_size_of(t->Matrix.elem); + return size/gb_max(elem_size, 1); +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray; @@ -3174,17 +3190,17 @@ i64 type_align_of_internal(Type *t, TypePath *path) { case Type_Matrix: { Type *elem = t->Matrix.elem; - i64 row_count = t->Matrix.row_count; - // i64 column_count = t->Matrix.column_count; + i64 row_count = gb_max(t->Matrix.row_count, 1); + bool pop = type_path_push(path, elem); if (path->failure) { return FAILURE_ALIGNMENT; } + // elem align is used here rather than size as it make a little more sense i64 elem_align = type_align_of_internal(elem, path); if (pop) type_path_pop(path); - i64 align = gb_clamp(elem_align * row_count, elem_align, build_context.max_align); - + i64 align = gb_min(next_pow2(elem_align * row_count), build_context.max_align); return align; } @@ -3935,6 +3951,13 @@ gbString write_type_to_string(gbString str, Type *type) { str = gb_string_append_fmt(str, ") "); str = write_type_to_string(str, type->RelativeSlice.slice_type); break; + + case Type_Matrix: + str = gb_string_appendc(str, gb_bprintf("[%d", cast(int)type->Matrix.row_count)); + str = gb_string_appendc(str, "; "); + str = gb_string_appendc(str, gb_bprintf("%d]", cast(int)type->Matrix.column_count)); + str = write_type_to_string(str, type->Matrix.elem); + break; } return str;