mirror of
https://github.com/odin-lang/Odin.git
synced 2026-05-31 07:55:33 +00:00
Improve code generation for intrinsics.transpose
This commit is contained in:
@@ -822,35 +822,32 @@ gb_internal lbValue lb_emit_matrix_transpose(lbProcedure *p, lbValue m, Type *ty
|
||||
GB_PANIC("TODO: transpose with changing layout");
|
||||
}
|
||||
|
||||
if (lb_is_matrix_simdable(mt) && lb_is_matrix_simdable(type)) {
|
||||
if (lb_is_matrix_simdable(mt, true) && lb_is_matrix_simdable(type, true)) {
|
||||
auto const do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
|
||||
return LLVMConstInt(lb_type(p->module, t_u32), val, false);
|
||||
};
|
||||
|
||||
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
||||
unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
||||
unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
||||
|
||||
auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
||||
unsigned other_stride = (row_count*column_count)/stride;
|
||||
|
||||
LLVMValueRef vector = lb_matrix_to_vector(p, m);
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count * column_count);
|
||||
for (unsigned i = 0; i < row_count; i++) {
|
||||
for (unsigned j = 0; j < column_count; j++) {
|
||||
unsigned offset = stride*j + i;
|
||||
mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
||||
mask_elems[other_stride*i + j] = do_u32(p, stride*j + i);
|
||||
}
|
||||
|
||||
// transpose mask
|
||||
LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
|
||||
LLVMValueRef row = llvm_basic_shuffle(p, vector, mask);
|
||||
rows[i] = row;
|
||||
}
|
||||
LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
|
||||
LLVMValueRef transposed_vector = llvm_basic_shuffle(p, vector, mask);
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
for_array(i, rows) {
|
||||
LLVMValueRef row = rows[i];
|
||||
lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i);
|
||||
LLVMValueRef ptr = dst_row_ptr.value;
|
||||
ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), "");
|
||||
LLVMBuildStore(p->builder, row, ptr);
|
||||
}
|
||||
LLVMValueRef res_ptr = res.addr.value;
|
||||
res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(transposed_vector), 0), "");
|
||||
|
||||
LLVMValueRef store = LLVMBuildStore(p->builder, transposed_vector, res_ptr);
|
||||
LLVMSetAlignment(store, cast(unsigned)type_align_of(type));
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user