diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 22530831b..598ab6d21 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -976,15 +976,122 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) { + + // if (LLVMIsALoadInst(lhs.value) && LLVMIsALoadInst(rhs.value)) { + // auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { + // return LLVMConstInt(lb_type(p->module, t_u32), val, false); + // }; + + // LLVMValueRef llvm_stride = do_u32(p, inner); + // LLVMValueRef llvm_false = LLVMConstInt(lb_type(p->module, t_llvm_bool), false, false); + + // LLVMValueRef lhs_args[] = {LLVMGetOperand(lhs.value, 0), llvm_stride, llvm_false, do_u32(p, outer_rows), do_u32(p, inner)}; + // LLVMValueRef rhs_args[] = {LLVMGetOperand(rhs.value, 0), llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)}; + // LLVMTypeRef types[] = {lb_type(p->module, elem)}; + + // LLVMValueRef lhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", lhs_args, gb_count_of(lhs_args), types, gb_count_of(types)); + // LLVMValueRef rhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", rhs_args, gb_count_of(rhs_args), types, gb_count_of(types)); + + // LLVMValueRef mul_args[] = {lhs_loaded, rhs_loaded, do_u32(p, outer_rows), do_u32(p, inner), do_u32(p, outer_columns)}; + // LLVMValueRef lhs_mul_rhs = lb_call_intrinsic(p, "llvm.matrix.multiply", mul_args, gb_count_of(mul_args), types, gb_count_of(types)); + + // lbAddr res = lb_add_local_generated(p, type, false); + + // LLVMValueRef store_args[] = {res.addr.value, lhs_mul_rhs, llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)}; + // lb_call_intrinsic(p, "llvm.matrix.column.major.store", store_args, gb_count_of(store_args), types, gb_count_of(types)); + + // return lb_addr_load(p, res); + // } + + + 
unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); - auto x_rows = slice_make(permanent_allocator(), outer_rows); - auto y_columns = slice_make(permanent_allocator(), outer_columns); - LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) { + // square matrix calculation + unsigned N = outer_columns; + + auto x_columns = slice_make(permanent_allocator(), N); + auto y_columns = slice_make(permanent_allocator(), N); + + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner); + LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask); + x_columns[i] = column; + } + + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner); + LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); + y_columns[i] = column; + } + + + auto z_columns = slice_make(permanent_allocator(), N); + + auto mask_elems = slice_make(permanent_allocator(), N); + auto temp_muls = slice_make(permanent_allocator(), N); + + for (unsigned i = 0; i < N; i++) { + for (unsigned j = 0; j < N; j++) { + LLVMValueRef mask = llvm_mask_same(p->module, j, N); + mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask); + } + for (unsigned j = 0; j < N; j++) { + if (is_type_float(elem)) { + temp_muls[j] = LLVMBuildFMul(p->builder, mask_elems[j], x_columns[j], ""); + // LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll); + } else { + temp_muls[j] = LLVMBuildMul(p->builder, mask_elems[j], x_columns[j], ""); + } + } + unsigned k = N; + while (k > 1) { + unsigned half = k/2; + for (unsigned j = 0; j < half; j++) { + if (is_type_float(elem)) { + temp_muls[j] = LLVMBuildFAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], ""); + // LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll); + } else { + temp_muls[j] 
= LLVMBuildAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], ""); + } + } + + if ((k&1) != 0) { + temp_muls[half] = temp_muls[k-1]; + } + k = (k+1)/2; + } + + z_columns[i] = temp_muls[0]; + } + + auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { + return LLVMConstInt(lb_type(p->module, t_u32), val, false); + }; + + lbAddr res = lb_add_local_generated(p, type, false); + LLVMValueRef dest_ptr = res.addr.value; + + LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0); + dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, ""); + for (unsigned i = 0; i < N; i++) { + LLVMValueRef indices[] = {do_u32(p, i)}; + LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, ""); + LLVMBuildStore(p->builder, z_columns[i], dst); + } + + return lb_addr_load(p, res); + } + + + auto x_rows = slice_make(permanent_allocator(), outer_rows); + auto y_columns = slice_make(permanent_allocator(), outer_columns); + auto mask_elems = slice_make(permanent_allocator(), inner); for (unsigned i = 0; i < outer_rows; i++) { for (unsigned j = 0; j < inner; j++) { @@ -1004,7 +1111,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, y_columns[i] = column; } - lbAddr res = lb_add_local_generated(p, type, true); + lbAddr res = lb_add_local_generated(p, type, false); for_array(i, x_rows) { LLVMValueRef x_row = x_rows[i]; for_array(j, y_columns) { @@ -1018,7 +1125,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, } if (!xt->Matrix.is_row_major) { - lbAddr res = lb_add_local_generated(p, type, true); + lbAddr res = lb_add_local_generated(p, type, false); auto inners = slice_make(permanent_allocator(), inner); @@ -1042,7 +1149,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, return lb_addr_load(p, res); } else { - lbAddr res = lb_add_local_generated(p, type, true); + lbAddr res = lb_add_local_generated(p, 
type, false); auto inners = slice_make(permanent_allocator(), inner);
diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp
index a04f91fbd..d101d28c2 100644
--- a/src/llvm_backend_utility.cpp
+++ b/src/llvm_backend_utility.cpp
@@ -2048,6 +2048,22 @@ gb_internal LLVMValueRef llvm_mask_zero(lbModule *m, unsigned count) {
 	return LLVMConstNull(LLVMVectorType(lb_type(m, t_u32), count));
 }
 
+// Returns a constant vector of `count` u32 lanes that all hold `value`.
+// lb_emit_matrix_mul feeds the result to llvm_basic_shuffle as its mask argument.
+// The lane array is scratch from temporary_allocator(); callers receive only the
+// LLVMValueRef produced by LLVMConstVector.
+gb_internal LLVMValueRef llvm_mask_same(lbModule *m, unsigned value, unsigned count) {
+	// Every lane is the same constant, so build it once instead of once per lane.
+	LLVMValueRef lane = lb_const_int(m, t_u32, value).value;
+
+	auto lanes = slice_make<LLVMValueRef>(temporary_allocator(), count);
+	for (unsigned i = 0; i < count; i++) {
+		lanes[i] = lane;
+	}
+	return LLVMConstVector(lanes.data, count);
+}
+
+
 #define LLVM_VECTOR_DUMMY_VALUE(type) LLVMGetUndef((type))
 // #define LLVM_VECTOR_DUMMY_VALUE(type) LLVMConstNull((type))
 