Improve code generation for intrinsics.transpose

This commit is contained in:
gingerBill
2026-05-11 13:41:30 +01:00
parent b752ff4bdb
commit 7d0dba1b82

View File

@@ -822,35 +822,32 @@ gb_internal lbValue lb_emit_matrix_transpose(lbProcedure *p, lbValue m, Type *ty
GB_PANIC("TODO: transpose with changing layout");
}
if (lb_is_matrix_simdable(mt) && lb_is_matrix_simdable(type)) {
if (lb_is_matrix_simdable(mt, true) && lb_is_matrix_simdable(type, true)) {
auto const do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
return LLVMConstInt(lb_type(p->module, t_u32), val, false);
};
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
unsigned row_count = cast(unsigned)mt->Matrix.row_count;
unsigned column_count = cast(unsigned)mt->Matrix.column_count;
auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
unsigned other_stride = (row_count*column_count)/stride;
LLVMValueRef vector = lb_matrix_to_vector(p, m);
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count * column_count);
for (unsigned i = 0; i < row_count; i++) {
for (unsigned j = 0; j < column_count; j++) {
unsigned offset = stride*j + i;
mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
mask_elems[other_stride*i + j] = do_u32(p, stride*j + i);
}
// transpose mask
LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
LLVMValueRef row = llvm_basic_shuffle(p, vector, mask);
rows[i] = row;
}
LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
LLVMValueRef transposed_vector = llvm_basic_shuffle(p, vector, mask);
lbAddr res = lb_add_local_generated(p, type, false);
lbAddr res = lb_add_local_generated(p, type, true);
for_array(i, rows) {
LLVMValueRef row = rows[i];
lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i);
LLVMValueRef ptr = dst_row_ptr.value;
ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), "");
LLVMBuildStore(p->builder, row, ptr);
}
LLVMValueRef res_ptr = res.addr.value;
res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(transposed_vector), 0), "");
LLVMValueRef store = LLVMBuildStore(p->builder, transposed_vector, res_ptr);
LLVMSetAlignment(store, cast(unsigned)type_align_of(type));
return lb_addr_load(p, res);
}