diff --git a/src/check_expr.cpp b/src/check_expr.cpp index ad12e00c8..ee7493553 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -2469,7 +2469,9 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) { } if (src->Matrix.row_count != src->Matrix.column_count) { - return false; + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + return src_count == dst_count; } if (dst->Matrix.row_count != dst->Matrix.column_count) { diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 554255f47..b543089e5 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -524,7 +524,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc lbValue single_elem = lb_const_value(m, elem, value, allow_local); single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); - i64 total_elem_count = matrix_type_total_elems(type); + i64 total_elem_count = matrix_type_total_internal_elems(type); LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); for (i64 i = 0; i < row; i++) { elems[matrix_indices_to_offset(type, i, i)] = single_elem.value; @@ -990,7 +990,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc } i64 max_count = type->Matrix.row_count*type->Matrix.column_count; - i64 total_count = matrix_type_total_elems(type); + i64 total_count = matrix_type_total_internal_elems(type); LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); if (cl->elems[0]->kind == Ast_FieldValue) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 9582be93c..eb88bbde0 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -508,7 +508,7 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { GB_ASSERT(mt->kind == Type_Matrix); LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); - unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); + unsigned total_count = cast(unsigned)matrix_type_total_internal_elems(mt); LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; @@ -948,7 +948,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue // pretend it is an array lbValue array_lhs = lhs; lbValue array_rhs = rhs; - Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_internal_elems(xt)); GB_ASSERT(type_size_of(array_type) == type_size_of(xt)); array_lhs.type = array_type; @@ -1941,15 +1941,33 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { GB_ASSERT(dst->kind == Type_Matrix); GB_ASSERT(src->kind == Type_Matrix); lbAddr v = lb_add_local_generated(p, t, true); - for (i64 j = 0; j < dst->Matrix.column_count; j++) { - for (i64 i = 0; i < dst->Matrix.row_count; i++) { - if (i < src->Matrix.row_count && j < src->Matrix.column_count) { - lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + + if (is_matrix_square(dst) && is_matrix_square(dst)) { + for (i64 j = 0; j < dst->Matrix.column_count; j++) { + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + if (i < src->Matrix.row_count && j < src->Matrix.column_count) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_emit_matrix_ev(p, value, i, j); + lb_emit_store(p, d, s); + } else if (i == j) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + lb_emit_store(p, d, s); + } + } + } + } else { + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + GB_ASSERT(dst_count == src_count); + + for (i64 j = 0; j < src->Matrix.column_count; j++) { + for (i64 i = 0; i < src->Matrix.row_count; i++) { lbValue s = lb_emit_matrix_ev(p, value, i, j); - lb_emit_store(p, d, s); - } else if (i == j) { - lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); - lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + i64 index = i + j*src->Matrix.row_count; + i64 dst_i = index%dst->Matrix.row_count; + i64 dst_j = index/dst->Matrix.row_count; + lbValue d = lb_emit_matrix_epi(p, v.addr, dst_i, dst_j); lb_emit_store(p, d, s); } } diff --git a/src/types.cpp b/src/types.cpp index d3fa363c2..3abcebdfb 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1293,7 +1293,7 @@ i64 matrix_type_stride_in_elems(Type *t) { } -i64 matrix_type_total_elems(Type *t) { +i64 matrix_type_total_internal_elems(Type *t) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); i64 size = type_size_of(t); @@ -1301,30 +1301,6 @@ i64 matrix_type_total_elems(Type *t) { return size/gb_max(elem_size, 1); } -void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) { - t = base_type(t); - GB_ASSERT(t->kind == Type_Matrix); - i64 row_count = t->Matrix.row_count; - i64 column_count = t->Matrix.column_count; - GB_ASSERT(0 <= index && index < row_count*column_count); - - i64 row_index = index / column_count; - i64 column_index = index % column_count; - - if (row_index_) *row_index_ = row_index; - if (column_index_) *column_index_ = column_index; -} - -i64 matrix_index_to_offset(Type *t, i64 index) { - t = base_type(t); - GB_ASSERT(t->kind == Type_Matrix); - - i64 row_index, column_index; - matrix_indices_from_index(t, index, &row_index, &column_index); - i64 stride_elems = matrix_type_stride_in_elems(t); - return stride_elems*column_index + row_index; -} - i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); @@ -1333,6 +1309,22 @@ i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { i64 stride_elems = matrix_type_stride_in_elems(t); return stride_elems*column_index + row_index; } +i64 matrix_index_to_offset(Type *t, i64 index) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + + i64 row_index = index%t->Matrix.row_count; + i64 column_index = index/t->Matrix.row_count; + return matrix_indices_to_offset(t, row_index, column_index); +} + + + +bool is_matrix_square(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + return t->Matrix.row_count == t->Matrix.column_count; +} bool is_type_valid_for_matrix_elems(Type *t) { t = base_type(t);