From b752ff4bdbda17d7f7e48b002e2aedd51b11f2f5 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 11 May 2026 13:28:54 +0100 Subject: [PATCH] Add a minor optimization for `row_major * row_major` --- src/llvm_backend_expr.cpp | 270 +++++++++++++++++++++-------------- src/llvm_backend_utility.cpp | 24 ++++ 2 files changed, 184 insertions(+), 110 deletions(-) diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 7f45b89dd..ba2bea7cd 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -672,7 +672,7 @@ gb_internal lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lh } } -gb_internal bool lb_is_matrix_simdable(Type *t) { +gb_internal bool lb_is_matrix_simdable(Type *t, bool ignore_layout=false) { Type *mt = base_type(t); GB_ASSERT(mt->kind == Type_Matrix); @@ -701,8 +701,10 @@ gb_internal bool lb_is_matrix_simdable(Type *t) { return false; } if (mt->Matrix.is_row_major) { - // TODO(bill): make #row_major matrices work with SIMD - return false; + if (!ignore_layout) { + // TODO(bill): make #row_major matrices work with SIMD + return false; + } } if (elem->kind == Type_Basic) { @@ -959,6 +961,10 @@ gb_internal lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { // TODO(bill): Handle edge case for f16 types on x86(-64) platforms + auto const do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { + return LLVMConstInt(lb_type(p->module, t_u32), val, false); + }; + Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); @@ -975,114 +981,179 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, unsigned inner = cast(unsigned)xt->Matrix.column_count; unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; - if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) { - unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); - unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); + if (lb_is_matrix_simdable(xt, true)) { + if (!xt->Matrix.is_row_major) { // #column_major + unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); + unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); - LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); - LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); + LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); - if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) { - // square matrix calculation - unsigned N = outer_columns; + if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) { + // square matrix calculation + unsigned N = outer_columns; - auto x_columns = slice_make(permanent_allocator(), N); - auto y_columns = slice_make(permanent_allocator(), N); + auto x_columns = slice_make(permanent_allocator(), N); + auto y_columns = slice_make(permanent_allocator(), N); - for (unsigned i = 0; i < N; i++) { - LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner); - LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask); - x_columns[i] = column; + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, N); + LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask); + x_columns[i] = column; + } + + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, N); + LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); + y_columns[i] = column; + } + + + auto z_columns = slice_make(permanent_allocator(), N); + auto mask_elems = slice_make(permanent_allocator(), N); + + for (unsigned i = 0; i < N; i++) { + for (unsigned j = 0; j < N; j++) { + LLVMValueRef mask = llvm_mask_same(p->module, j, N); + mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask); + } + z_columns[i] = llvm_vector_mul_pairwise_reduce_add(p, mask_elems, x_columns); + } + + lbAddr res = lb_add_local_generated(p, type, false); + LLVMValueRef dest_ptr = res.addr.value; + + LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0); + dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, ""); + for (unsigned i = 0; i < N; i++) { + LLVMValueRef indices[] = {do_u32(p, i)}; + LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, ""); + LLVMBuildStore(p->builder, z_columns[i], dst); + } + + return lb_addr_load(p, res); } - for (unsigned i = 0; i < N; i++) { + + auto x_rows = slice_make(permanent_allocator(), outer_rows); + auto y_columns = slice_make(permanent_allocator(), outer_columns); + + auto mask_elems = slice_make(permanent_allocator(), inner); + for (unsigned i = 0; i < outer_rows; i++) { + for (unsigned j = 0; j < inner; j++) { + unsigned offset = x_stride*j + i; + mask_elems[j] = do_u32(p, offset); + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner); + LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask); + x_rows[i] = row; + } + + for (unsigned i = 0; i < outer_columns; i++) { LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner); LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); y_columns[i] = column; } - - auto z_columns = slice_make(permanent_allocator(), N); - - auto mask_elems = slice_make(permanent_allocator(), N); - auto temp_muls = slice_make(permanent_allocator(), N); - - for (unsigned i = 0; i < N; i++) { - for (unsigned j = 0; j < N; j++) { - LLVMValueRef mask = llvm_mask_same(p->module, j, N); - mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask); + lbAddr res = lb_add_local_generated(p, type, false); + for_array(i, x_rows) { + LLVMValueRef x_row = x_rows[i]; + for_array(j, y_columns) { + LLVMValueRef y_column = y_columns[j]; + LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + LLVMBuildStore(p->builder, elem, dst.value); } - for (unsigned j = 0; j < N; j++) { - temp_muls[j] = llvm_vector_mul(p, mask_elems[j], x_columns[j]); + } + return lb_addr_load(p, res); + } else { // #row_major + unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); + unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); + + LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); + LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + + if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) { + // square matrix calculation + unsigned N = outer_columns; + + auto x_rows = slice_make(permanent_allocator(), N); + auto y_rows = slice_make(permanent_allocator(), N); + + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, N); + LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask); + x_rows[i] = column; } - unsigned k = N; - while (k > 1) { - unsigned half = k/2; - for (unsigned j = 0; j < half; j++) { - temp_muls[j] = llvm_vector_add(p, temp_muls[2*j + 0], temp_muls[2*j + 1]); + + for (unsigned i = 0; i < N; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, N); + LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); + y_rows[i] = column; + } + + + auto z_rows = slice_make(permanent_allocator(), N); + auto mask_elems = slice_make(permanent_allocator(), N); + + for (unsigned i = 0; i < N; i++) { + for (unsigned j = 0; j < N; j++) { + LLVMValueRef mask = llvm_mask_same(p->module, j, N); + mask_elems[j] = llvm_basic_shuffle(p, x_rows[i], mask); } - - if ((k&1) != 0) { - temp_muls[half] = temp_muls[k-1]; - } - k = (k+1)/2; + z_rows[i] = llvm_vector_mul_pairwise_reduce_add(p, mask_elems, y_rows); } - z_columns[i] = temp_muls[0]; + lbAddr res = lb_add_local_generated(p, type, false); + LLVMValueRef dest_ptr = res.addr.value; + + LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_rows[0]), 0); + dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, ""); + for (unsigned i = 0; i < N; i++) { + LLVMValueRef indices[] = {do_u32(p, i)}; + LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_rows[0]), dest_ptr, indices, 1, ""); + LLVMBuildStore(p->builder, z_rows[i], dst); + } + + return lb_addr_load(p, res); } - auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { - return LLVMConstInt(lb_type(p->module, t_u32), val, false); - }; + auto x_rows = slice_make(permanent_allocator(), outer_rows); + auto y_columns = slice_make(permanent_allocator(), outer_columns); + + for (unsigned i = 0; i < outer_rows; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner); + LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask); + x_rows[i] = row; + } + + auto mask_elems = slice_make(permanent_allocator(), inner); + for (unsigned i = 0; i < outer_columns; i++) { + for (unsigned j = 0; j < inner; j++) { + unsigned offset = x_stride*j + i; + mask_elems[j] = do_u32(p, offset); + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner); + LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); + y_columns[i] = column; + } lbAddr res = lb_add_local_generated(p, type, false); - LLVMValueRef dest_ptr = res.addr.value; - - LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0); - dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, ""); - for (unsigned i = 0; i < N; i++) { - LLVMValueRef indices[] = {do_u32(p, i)}; - LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, ""); - LLVMBuildStore(p->builder, z_columns[i], dst); + for_array(i, x_rows) { + LLVMValueRef x_row = x_rows[i]; + for_array(j, y_columns) { + LLVMValueRef y_column = y_columns[j]; + LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + LLVMBuildStore(p->builder, elem, dst.value); + } } - return lb_addr_load(p, res); } - - - auto x_rows = slice_make(permanent_allocator(), outer_rows); - auto y_columns = slice_make(permanent_allocator(), outer_columns); - - auto mask_elems = slice_make(permanent_allocator(), inner); - for (unsigned i = 0; i < outer_rows; i++) { - for (unsigned j = 0; j < inner; j++) { - unsigned offset = x_stride*j + i; - mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; - } - - // transpose mask - LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner); - LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask); - x_rows[i] = row; - } - - for (unsigned i = 0; i < outer_columns; i++) { - LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner); - LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask); - y_columns[i] = column; - } - - lbAddr res = lb_add_local_generated(p, type, false); - for_array(i, x_rows) { - LLVMValueRef x_row = x_rows[i]; - for_array(j, y_columns) { - LLVMValueRef y_column = y_columns[j]; - LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column); - lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); - LLVMBuildStore(p->builder, elem, dst.value); - } - } - return lb_addr_load(p, res); } if (!xt->Matrix.is_row_major) { @@ -1186,28 +1257,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal } } - auto temps = slice_make(permanent_allocator(), column_count); - for (unsigned i = 0; i < column_count; i++) { - temps[i] = llvm_vector_mul(p, m_columns[i], v_rows[i]); - } - - GB_ASSERT(column_count > 0); - - unsigned k = column_count; - while (k > 1) { - unsigned half = k/2; - for (unsigned j = 0; j < half; j++) { - temps[j] = llvm_vector_add(p, temps[2*j + 0], temps[2*j + 1]); - } - - if ((k&1) != 0) { - temps[half] = temps[k-1]; - } - k = (k+1)/2; - } - - LLVMValueRef vector = temps[0]; - + LLVMValueRef vector = llvm_vector_mul_pairwise_reduce_add(p, m_columns, v_rows); return lb_matrix_cast_vector_to_type(p, vector, type); } diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index d101d28c2..25481b2ed 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -2230,6 +2230,30 @@ gb_internal LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMVal return LLVMBuildFMul(p->builder, a, b, ""); } +gb_internal LLVMValueRef llvm_vector_mul_pairwise_reduce_add(lbProcedure *p, Slice const &a, Slice const &b) { + GB_ASSERT(a.count == b.count); + + auto temps = slice_make(temporary_allocator(), a.count); + for (unsigned i = 0; i < a.count; i++) { + temps[i] = llvm_vector_mul(p, a[i], b[i]); + } + + unsigned k = cast(unsigned)a.count; + while (k > 1) { + unsigned half = k/2; + for (unsigned j = 0; j < half; j++) { + temps[j] = llvm_vector_add(p, temps[2*j + 0], temps[2*j + 1]); + } + + if ((k&1) != 0) { + temps[half] = temps[k-1]; + } + k = (k+1)/2; + } + + return temps[0]; +} + gb_internal LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b));