Add optimization edge case for square mat * mat

This commit is contained in:
gingerBill
2026-05-11 11:16:27 +01:00
parent 1ba0ab8790
commit 65ff188c1c
2 changed files with 122 additions and 6 deletions

View File

@@ -976,15 +976,122 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) {
// if (LLVMIsALoadInst(lhs.value) && LLVMIsALoadInst(rhs.value)) {
// auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
// return LLVMConstInt(lb_type(p->module, t_u32), val, false);
// };
// LLVMValueRef llvm_stride = do_u32(p, inner);
// LLVMValueRef llvm_false = LLVMConstInt(lb_type(p->module, t_llvm_bool), false, false);
// LLVMValueRef lhs_args[] = {LLVMGetOperand(lhs.value, 0), llvm_stride, llvm_false, do_u32(p, outer_rows), do_u32(p, inner)};
// LLVMValueRef rhs_args[] = {LLVMGetOperand(rhs.value, 0), llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)};
// LLVMTypeRef types[] = {lb_type(p->module, elem)};
// LLVMValueRef lhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", lhs_args, gb_count_of(lhs_args), types, gb_count_of(types));
// LLVMValueRef rhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", rhs_args, gb_count_of(rhs_args), types, gb_count_of(types));
// LLVMValueRef mul_args[] = {lhs_loaded, rhs_loaded, do_u32(p, outer_rows), do_u32(p, inner), do_u32(p, outer_columns)};
// LLVMValueRef lhs_mul_rhs = lb_call_intrinsic(p, "llvm.matrix.multiply", mul_args, gb_count_of(mul_args), types, gb_count_of(types));
// lbAddr res = lb_add_local_generated(p, type, false);
// LLVMValueRef store_args[] = {res.addr.value, lhs_mul_rhs, llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)};
// lb_call_intrinsic(p, "llvm.matrix.column.major.store", store_args, gb_count_of(store_args), types, gb_count_of(types));
// return lb_addr_load(p, res);
// }
unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) {
// square matrix calculation
unsigned N = outer_columns;
auto x_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
for (unsigned i = 0; i < N; i++) {
LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner);
LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask);
x_columns[i] = column;
}
for (unsigned i = 0; i < N; i++) {
LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
y_columns[i] = column;
}
auto z_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), N);
auto temp_muls = slice_make<LLVMValueRef>(permanent_allocator(), N);
for (unsigned i = 0; i < N; i++) {
for (unsigned j = 0; j < N; j++) {
LLVMValueRef mask = llvm_mask_same(p->module, j, N);
mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask);
}
for (unsigned j = 0; j < N; j++) {
if (is_type_float(elem)) {
temp_muls[j] = LLVMBuildFMul(p->builder, mask_elems[j], x_columns[j], "");
// LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll);
} else {
temp_muls[j] = LLVMBuildMul(p->builder, mask_elems[j], x_columns[j], "");
}
}
unsigned k = N;
while (k > 1) {
unsigned half = k/2;
for (unsigned j = 0; j < half; j++) {
if (is_type_float(elem)) {
temp_muls[j] = LLVMBuildFAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], "");
// LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll);
} else {
temp_muls[j] = LLVMBuildAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], "");
}
}
if ((k&1) != 0) {
temp_muls[half] = temp_muls[k-1];
}
k = (k+1)/2;
}
z_columns[i] = temp_muls[0];
}
auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
return LLVMConstInt(lb_type(p->module, t_u32), val, false);
};
lbAddr res = lb_add_local_generated(p, type, false);
LLVMValueRef dest_ptr = res.addr.value;
LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0);
dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, "");
for (unsigned i = 0; i < N; i++) {
LLVMValueRef indices[] = {do_u32(p, i)};
LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, "");
LLVMBuildStore(p->builder, z_columns[i], dst);
}
return lb_addr_load(p, res);
}
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
for (unsigned i = 0; i < outer_rows; i++) {
for (unsigned j = 0; j < inner; j++) {
@@ -1004,7 +1111,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
y_columns[i] = column;
}
lbAddr res = lb_add_local_generated(p, type, true);
lbAddr res = lb_add_local_generated(p, type, false);
for_array(i, x_rows) {
LLVMValueRef x_row = x_rows[i];
for_array(j, y_columns) {
@@ -1018,7 +1125,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
}
if (!xt->Matrix.is_row_major) {
lbAddr res = lb_add_local_generated(p, type, true);
lbAddr res = lb_add_local_generated(p, type, false);
auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
@@ -1042,7 +1149,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
return lb_addr_load(p, res);
} else {
lbAddr res = lb_add_local_generated(p, type, true);
lbAddr res = lb_add_local_generated(p, type, false);
auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);

View File

@@ -2048,6 +2048,15 @@ gb_internal LLVMValueRef llvm_mask_zero(lbModule *m, unsigned count) {
return LLVMConstNull(LLVMVectorType(lb_type(m, t_u32), count));
}
gb_internal LLVMValueRef llvm_mask_same(lbModule *m, unsigned value, unsigned count) {
auto iota = slice_make<LLVMValueRef>(temporary_allocator(), count);
for (unsigned i = 0; i < count; i++) {
iota[i] = lb_const_int(m, t_u32, value).value;
}
return LLVMConstVector(iota.data, count);
}
#define LLVM_VECTOR_DUMMY_VALUE(type) LLVMGetUndef((type))
// #define LLVM_VECTOR_DUMMY_VALUE(type) LLVMConstNull((type))