diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 959bbb078..f115dd6b2 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -3039,8 +3039,8 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand x->type = xt; goto matrix_success; } else { - GB_ASSERT(is_type_matrix(yt)); GB_ASSERT(!is_type_matrix(xt)); + GB_ASSERT(is_type_matrix(yt)); if (op.kind == Token_Mul) { // NOTE(bill): no need to handle the matrix case here since it should be handled above @@ -3061,6 +3061,9 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count); } goto matrix_success; + } else if (are_types_identical(yt->Matrix.elem, xt)) { + x->type = check_matrix_type_hint(y->type, type_hint); + return; } } if (!are_types_identical(xt, yt)) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 023e7b363..7c92c517c 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -1,4 +1,4 @@ -lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false); +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise); lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, Ast *left, Ast *right, Type *type) { lbModule *m = p->module; @@ -987,7 +987,6 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise) { GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); - if (op == Token_Mul && !component_wise) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); @@ -1001,8 +1000,22 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue } else if (is_type_array_like(xt)) { GB_ASSERT(yt->kind == Type_Matrix); return lb_emit_vector_mul_matrix(p, lhs, rhs, type); - } + } else { + GB_ASSERT(xt->kind == Type_Basic); + GB_ASSERT(yt->kind == Type_Matrix); + GB_ASSERT(is_type_matrix(type)); + Type *array_type = alloc_type_array(yt->Matrix.elem, matrix_type_total_internal_elems(yt)); + GB_ASSERT(type_size_of(array_type) == type_size_of(yt)); + + lbValue array_lhs = lb_emit_conv(p, lhs, array_type); + lbValue array_rhs = rhs; + array_rhs.type = array_type; + + lbValue array = lb_emit_arith(p, op, array_lhs, array_rhs, array_type); + array.type = type; + return array; + } } else { if (is_type_matrix(lhs.type)) { rhs = lb_emit_conv(p, rhs, lhs.type); @@ -1047,7 +1060,7 @@ lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Ty if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) { return lb_emit_arith_array(p, op, lhs, rhs, type); } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) { - return lb_emit_arith_matrix(p, op, lhs, rhs, type); + return lb_emit_arith_matrix(p, op, lhs, rhs, type, false); } else if (is_type_complex(type)) { lhs = lb_emit_conv(p, lhs, type); rhs = lb_emit_conv(p, rhs, type); @@ -1320,7 +1333,7 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) { if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) { lbValue left = lb_build_expr(p, be->left); lbValue right = lb_build_expr(p, be->right); - return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type)); + return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type), false); }