diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index f6b9934ef..f20c52e88 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -705,31 +705,37 @@ gb_internal lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type
 	lbAddr res = lb_add_local_generated(p, type, true);
 
-	i64 row_count = mt->Matrix.row_count;
-	i64 column_count = mt->Matrix.column_count;
-	TEMPORARY_ALLOCATOR_GUARD();
+	GB_ASSERT(type_size_of(type) == type_size_of(m.type));
 
-	auto srcs = array_make(temporary_allocator(), 0, row_count*column_count);
-	auto dsts = array_make(temporary_allocator(), 0, row_count*column_count);
+	lbValue m_ptr = lb_address_from_load_or_generate_local(p, m);
+	lbValue n = lb_const_int(p->module, t_int, type_size_of(type));
+	lb_mem_copy_non_overlapping(p, res.addr, m_ptr, n);
 
-	for (i64 j = 0; j < column_count; j++) {
-		for (i64 i = 0; i < row_count; i++) {
-			lbValue src = lb_emit_matrix_ev(p, m, i, j);
-			array_add(&srcs, src);
-		}
-	}
+	// i64 row_count = mt->Matrix.row_count;
+	// i64 column_count = mt->Matrix.column_count;
+	// TEMPORARY_ALLOCATOR_GUARD();
 
-	for (i64 j = 0; j < column_count; j++) {
-		for (i64 i = 0; i < row_count; i++) {
-			lbValue dst = lb_emit_array_epi(p, res.addr, i + j*row_count);
-			array_add(&dsts, dst);
-		}
-	}
+	// auto srcs = array_make(temporary_allocator(), 0, row_count*column_count);
+	// auto dsts = array_make(temporary_allocator(), 0, row_count*column_count);
 
-	GB_ASSERT(srcs.count == dsts.count);
-	for_array(i, srcs) {
-		lb_emit_store(p, dsts[i], srcs[i]);
-	}
+	// for (i64 j = 0; j < column_count; j++) {
+	// 	for (i64 i = 0; i < row_count; i++) {
+	// 		lbValue src = lb_emit_matrix_ev(p, m, i, j);
+	// 		array_add(&srcs, src);
+	// 	}
+	// }
+
+	// for (i64 j = 0; j < column_count; j++) {
+	// 	for (i64 i = 0; i < row_count; i++) {
+	// 		lbValue dst = lb_emit_array_epi(p, res.addr, i + j*row_count);
+	// 		array_add(&dsts, dst);
+	// 	}
+	// }
+
+	// GB_ASSERT(srcs.count == dsts.count);
+	// for_array(i, srcs) {
+	// 	lb_emit_store(p, dsts[i], srcs[i]);
+	// }
 
 	return lb_addr_load(p, res);
 }
diff --git a/src/types.cpp b/src/types.cpp
index 63182f5c4..a9a7d6dda 100644
--- a/src/types.cpp
+++ b/src/types.cpp
@@ -1474,6 +1474,7 @@ gb_internal i64 matrix_align_of(Type *t, struct TypePath *tp) {
 	Type *elem = t->Matrix.elem;
 	i64 row_count = gb_max(t->Matrix.row_count, 1);
+	i64 column_count = gb_max(t->Matrix.column_count, 1);
 
 	bool pop = type_path_push(tp, elem);
 	if (tp->failure) {
@@ -1491,7 +1492,7 @@
 	// could be maximally aligned but as a compromise, having no padding will be
 	// beneficial to third libraries that assume no padding
 
-	i64 total_expected_size = row_count*t->Matrix.column_count*elem_size;
+	i64 total_expected_size = row_count*column_count*elem_size;
 	// i64 min_alignment = prev_pow2(elem_align * row_count);
 	i64 min_alignment = prev_pow2(total_expected_size);
 	while (total_expected_size != 0 && (total_expected_size % min_alignment) != 0) {
@@ -1523,12 +1524,15 @@ gb_internal i64 matrix_type_stride_in_bytes(Type *t, struct TypePath *tp) {
 	i64 stride_in_bytes = 0;
 
 	// NOTE(bill, 2021-10-25): The alignment strategy here is to have zero padding
-	// It would be better for performance to pad each column so that each column
+	// It would be better for performance to pad each column/row so that each column/row
 	// could be maximally aligned but as a compromise, having no padding will be
 	// beneficial to third libraries that assume no padding
 
-	i64 row_count = t->Matrix.row_count;
-	stride_in_bytes = elem_size*row_count;
-
+
+	if (t->Matrix.is_row_major) {
+		stride_in_bytes = elem_size*t->Matrix.column_count;
+	} else {
+		stride_in_bytes = elem_size*t->Matrix.row_count;
+	}
 	t->Matrix.stride_in_bytes = stride_in_bytes;
 	return stride_in_bytes;
 }
@@ -4217,7 +4221,11 @@ gb_internal i64 type_size_of_internal(Type *t, TypePath *path) {
 
 	case Type_Matrix: {
 		i64 stride_in_bytes = matrix_type_stride_in_bytes(t, path);
-		return stride_in_bytes * t->Matrix.column_count;
+		if (t->Matrix.is_row_major) {
+			return stride_in_bytes * t->Matrix.row_count;
+		} else {
+			return stride_in_bytes * t->Matrix.column_count;
+		}
 	}
 
 	case Type_BitField:
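
The lb_emit_matrix_flatten rewrite is safe because flattening is a pure
reinterpretation: with the zero-padding layout, the matrix's bytes already are
the flattened array's bytes, which is what the new GB_ASSERT on type_size_of
checks before the single copy. A minimal standalone C++ sketch of that
equivalence (illustrative only, not compiler code):

// Sketch: with zero padding, a column-major matrix's memory layout already
// matches its flattened array, so one memcpy replaces the per-element
// load/store loop that lb_emit_matrix_flatten used to emit.
#include <cassert>
#include <cstring>

int main() {
	enum { Rows = 3, Cols = 2 };
	// m[j] is column j; elements of a column are contiguous (column-major).
	float m[Cols][Rows] = {{1, 2, 3}, {4, 5, 6}};

	// Old strategy: store element (i, j) into flat[i + j*Rows].
	float flat_loop[Rows * Cols];
	for (int j = 0; j < Cols; j++) {
		for (int i = 0; i < Rows; i++) {
			flat_loop[i + j*Rows] = m[j][i];
		}
	}

	// New strategy: copy the whole object in one non-overlapping memcpy,
	// valid because both sides have the same size and no padding.
	float flat_copy[Rows * Cols];
	static_assert(sizeof flat_copy == sizeof m, "no padding assumed");
	std::memcpy(flat_copy, m, sizeof m);

	assert(std::memcmp(flat_loop, flat_copy, sizeof flat_loop) == 0);
	return 0;
}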
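The types.cpp changes make stride and total size layout-aware: the stride is
one column of elements when column-major and one row when row-major, and the
total size multiplies the stride by the remaining dimension, so both layouts
stay padding-free at elem_size*row_count*column_count. A hedged sketch of that
math under the same zero-padding assumption (the struct and function names
here are illustrative, not the compiler's actual API):

// Sketch of the stride/size rule the patch implements.
#include <cstdint>
#include <cstdio>

struct MatrixLayout {
	int64_t elem_size;
	int64_t row_count;
	int64_t column_count;
	bool    is_row_major;
};

// Stride: bytes between consecutive rows (row-major) or columns (column-major).
int64_t stride_in_bytes(const MatrixLayout &m) {
	return m.is_row_major ? m.elem_size * m.column_count
	                      : m.elem_size * m.row_count;
}

// Total size: stride times the count of the strided unit. Either way this
// reduces to elem_size*row_count*column_count, matching total_expected_size
// in matrix_align_of.
int64_t size_in_bytes(const MatrixLayout &m) {
	int64_t s = stride_in_bytes(m);
	return m.is_row_major ? s * m.row_count : s * m.column_count;
}

int main() {
	MatrixLayout col = {4, 3, 2, false}; // e.g. matrix[3, 2]f32
	MatrixLayout row = {4, 3, 2, true};  // e.g. #row_major matrix[3, 2]f32
	std::printf("column-major: stride=%lld size=%lld\n",
	            (long long)stride_in_bytes(col), (long long)size_in_bytes(col));
	std::printf("row-major:    stride=%lld size=%lld\n",
	            (long long)stride_in_bytes(row), (long long)size_in_bytes(row));
	return 0;
}

Both layouts report size 24 for a 3x2 matrix of f32 here; only the stride
(12 vs 8) differs, which is why type_size_of_internal multiplies the stride by
row_count in the row-major case and column_count otherwise.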