mirror of
https://github.com/odin-lang/Odin.git
synced 2026-02-28 05:44:57 +00:00
Make lb_emit_matrix_mul_vector use SIMD if possible
This commit is contained in:
@@ -567,11 +567,10 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
|
||||
GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
|
||||
GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
|
||||
|
||||
if (!lb_matrix_elem_simple(xt)) {
|
||||
goto slow_form;
|
||||
if (lb_matrix_elem_simple(xt)) {
|
||||
// TODO(bill): SIMD version
|
||||
}
|
||||
|
||||
slow_form:
|
||||
{
|
||||
Type *elem = xt->Matrix.elem;
|
||||
|
||||
@@ -618,6 +617,69 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
||||
GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
||||
|
||||
Type *elem = mt->Matrix.elem;
|
||||
LLVMTypeRef elem_type = lb_type(p->module, elem);
|
||||
|
||||
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
||||
|
||||
if (lb_matrix_elem_simple(mt)) {
|
||||
unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
|
||||
unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
||||
auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
||||
auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
||||
|
||||
unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
||||
LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
||||
|
||||
LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value;
|
||||
LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), "");
|
||||
LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
||||
|
||||
|
||||
for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
|
||||
LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, "");
|
||||
m_columns[column_index] = column;
|
||||
}
|
||||
|
||||
for (unsigned row_index = 0; row_index < column_count; row_index++) {
|
||||
LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
|
||||
LLVMValueRef row = llvm_splat(p, value, row_count);
|
||||
v_rows[row_index] = row;
|
||||
}
|
||||
|
||||
GB_ASSERT(column_count > 0);
|
||||
|
||||
LLVMValueRef vector = nullptr;
|
||||
if (is_type_float(elem)) {
|
||||
for (i64 i = 0; i < column_count; i++) {
|
||||
LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], "");
|
||||
if (i == 0) {
|
||||
vector = product;
|
||||
} else {
|
||||
vector = LLVMBuildFAdd(p->builder, vector, product, "");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (i64 i = 0; i < column_count; i++) {
|
||||
LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], "");
|
||||
if (i == 0) {
|
||||
vector = product;
|
||||
} else {
|
||||
vector = LLVMBuildAdd(p->builder, vector, product, "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
LLVMValueRef res_ptr = res.addr.value;
|
||||
unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
||||
LLVMSetAlignment(res_ptr, alignment);
|
||||
|
||||
res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
||||
LLVMBuildStore(p->builder, vector, res_ptr);
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user