diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index ba2bea7cd..968d27a4e 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -822,35 +822,32 @@ gb_internal lbValue lb_emit_matrix_transpose(lbProcedure *p, lbValue m, Type *ty GB_PANIC("TODO: transpose with changing layout"); } - if (lb_is_matrix_simdable(mt) && lb_is_matrix_simdable(type)) { + if (lb_is_matrix_simdable(mt, true) && lb_is_matrix_simdable(type, true)) { + auto const do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { + return LLVMConstInt(lb_type(p->module, t_u32), val, false); + }; + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); unsigned row_count = cast(unsigned)mt->Matrix.row_count; unsigned column_count = cast(unsigned)mt->Matrix.column_count; - - auto rows = slice_make(permanent_allocator(), row_count); - auto mask_elems = slice_make(permanent_allocator(), column_count); + unsigned other_stride = (row_count*column_count)/stride; LLVMValueRef vector = lb_matrix_to_vector(p, m); + auto mask_elems = slice_make(permanent_allocator(), row_count * column_count); for (unsigned i = 0; i < row_count; i++) { for (unsigned j = 0; j < column_count; j++) { - unsigned offset = stride*j + i; - mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; + mask_elems[other_stride*i + j] = do_u32(p, stride*j + i); } - - // transpose mask - LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count); - LLVMValueRef row = llvm_basic_shuffle(p, vector, mask); - rows[i] = row; } + LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count); + LLVMValueRef transposed_vector = llvm_basic_shuffle(p, vector, mask); + lbAddr res = lb_add_local_generated(p, type, false); - lbAddr res = lb_add_local_generated(p, type, true); - for_array(i, rows) { - LLVMValueRef row = rows[i]; - lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i); - LLVMValueRef ptr = dst_row_ptr.value; - ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), ""); - LLVMBuildStore(p->builder, row, ptr); - } + LLVMValueRef res_ptr = res.addr.value; + res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(transposed_vector), 0), ""); + + LLVMValueRef store = LLVMBuildStore(p->builder, transposed_vector, res_ptr); + LLVMSetAlignment(store, cast(unsigned)type_align_of(type)); return lb_addr_load(p, res); }