From e0b9475378f4d69ebaf3e141ed941674b2c0d3f3 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Thu, 21 Oct 2021 01:14:44 +0100 Subject: [PATCH] Allow casting between square matrices of the same element type --- src/check_expr.cpp | 19 +++++++++++++++++ src/check_type.cpp | 10 ++++----- src/llvm_backend_expr.cpp | 44 ++++++++++++++++++++++++++++----------- 3 files changed, 56 insertions(+), 17 deletions(-) diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 498bf78c7..ad12e00c8 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -2460,6 +2460,24 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) { if (is_type_quaternion(src) && is_type_quaternion(dst)) { return true; } + + if (is_type_matrix(src) && is_type_matrix(dst)) { + GB_ASSERT(src->kind == Type_Matrix); + GB_ASSERT(dst->kind == Type_Matrix); + if (!are_types_identical(src->Matrix.elem, dst->Matrix.elem)) { + return false; + } + + if (src->Matrix.row_count != src->Matrix.column_count) { + return false; + } + + if (dst->Matrix.row_count != dst->Matrix.column_count) { + return false; + } + + return true; + } // Cast between pointers @@ -8838,6 +8856,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case Ast_EnumType: case Ast_MapType: case Ast_BitSetType: + case Ast_MatrixType: o->mode = Addressing_Type; o->type = check_type(c, node); break; diff --git a/src/check_type.cpp b/src/check_type.cpp index d9302c65a..21c8a9f19 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -1154,7 +1154,11 @@ Type *determine_type_from_polymorphic(CheckerContext *ctx, Type *poly_type, Oper bool show_error = modify_type && !ctx->hide_polymorphic_errors; if (!is_operand_value(operand)) { if (show_error) { - error(operand.expr, "Cannot determine polymorphic type from parameter"); + gbString pts = type_to_string(poly_type); + gbString ots = type_to_string(operand.type); + defer (gb_string_free(pts)); + defer (gb_string_free(ots)); + error(operand.expr, "Cannot determine polymorphic type from parameter: '%s' to '%s'", ots, pts); } return t_invalid; } @@ -2839,10 +2843,6 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t case_ast_node(mt, MatrixType, e); - bool ips = ctx->in_polymorphic_specialization; - defer (ctx->in_polymorphic_specialization = ips); - ctx->in_polymorphic_specialization = false; - check_matrix_type(ctx, type, e); set_base_type(named_type, *type); return true; diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index cdc1deea1..9582be93c 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -476,7 +476,7 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } } -bool lb_matrix_elem_simple(Type *t) { +bool lb_is_matrix_simdable(Type *t) { Type *mt = base_type(t); GB_ASSERT(mt->kind == Type_Matrix); @@ -555,7 +555,7 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { Type *mt = base_type(m.type); GB_ASSERT(mt->kind == Type_Matrix); - if (lb_matrix_elem_simple(mt)) { + if (lb_is_matrix_simdable(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; unsigned column_count = cast(unsigned)mt->Matrix.column_count; @@ -623,7 +623,7 @@ lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) { Type *mt = base_type(m.type); GB_ASSERT(mt->kind == Type_Matrix); - if (lb_matrix_elem_simple(mt)) { + if (lb_is_matrix_simdable(mt)) { LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m); return lb_matrix_cast_vector_to_type(p, vector, type); } @@ -690,7 +690,7 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) unsigned inner = cast(unsigned)xt->Matrix.column_count; unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; - if (lb_matrix_elem_simple(xt)) { + if (lb_is_matrix_simdable(xt)) { unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); @@ -773,7 +773,7 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type Type *elem = mt->Matrix.elem; - if (lb_matrix_elem_simple(mt)) { + if (lb_is_matrix_simdable(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; @@ -819,9 +819,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type lbValue a = lb_emit_matrix_ev(p, lhs, i, j); lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j); - lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem); - lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem); - lb_emit_store(p, dst, d); + lbValue c = lb_emit_mul_add(p, a, b, d0, elem); + lb_emit_store(p, dst, c); } } @@ -842,7 +841,7 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type Type *elem = mt->Matrix.elem; - if (lb_matrix_elem_simple(mt)) { + if (lb_is_matrix_simdable(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; @@ -903,9 +902,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k); lbValue b = lb_emit_matrix_ev(p, rhs, k, j); - lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem); - lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem); - lb_emit_store(p, dst, d); + lbValue c = lb_emit_mul_add(p, a, b, d0, elem); + lb_emit_store(p, dst, c); } } @@ -1938,6 +1936,28 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { return lb_addr_load(p, v); } + + if (is_type_matrix(dst) && is_type_matrix(src)) { + GB_ASSERT(dst->kind == Type_Matrix); + GB_ASSERT(src->kind == Type_Matrix); + lbAddr v = lb_add_local_generated(p, t, true); + for (i64 j = 0; j < dst->Matrix.column_count; j++) { + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + if (i < src->Matrix.row_count && j < src->Matrix.column_count) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_emit_matrix_ev(p, value, i, j); + lb_emit_store(p, d, s); + } else if (i == j) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + lb_emit_store(p, d, s); + } + } + } + return lb_addr_load(p, v); + } + + if (is_type_any(dst)) { if (is_type_untyped_nil(src)) {