From d0d9a3a4f4f3b4bc528c73ffcecb31d3eb4162a7 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 14:49:20 +0100 Subject: [PATCH] Make `lb_emit_matrix_mul` SIMD if possible --- src/llvm_backend_expr.cpp | 144 ++++++++++++++++++++--------------- src/llvm_backend_utility.cpp | 29 ++++++- 2 files changed, 110 insertions(+), 63 deletions(-) diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index c0a7a9edf..22e66c147 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -557,6 +557,20 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) } +LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { + Type *mt = base_type(matrix.type); + GB_ASSERT(mt->kind == Type_Matrix); + LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); + + unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); + LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); + + LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; + LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), ""); + LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); + return matrix_vector; +} + lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); @@ -567,31 +581,72 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + Type *elem = xt->Matrix.elem; + + unsigned outer_rows = cast(unsigned)xt->Matrix.row_count; + unsigned inner = cast(unsigned)xt->Matrix.column_count; + unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; + if (lb_matrix_elem_simple(xt)) { - // TODO(bill): SIMD version + unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); + unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); + + auto x_rows = slice_make(permanent_allocator(), outer_rows); + auto y_columns = slice_make(permanent_allocator(), outer_columns); + + + LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); + LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + + for (unsigned i = 0; i < outer_rows; i++) { + auto mask_elems = slice_make(temporary_allocator(), inner); + for (unsigned j = 0; j < inner; j++) { + unsigned offset = x_stride*j + i; + mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner); + LLVMValueRef row = LLVMBuildShuffleVector(p->builder, x_vector, LLVMGetUndef(LLVMTypeOf(x_vector)), mask, ""); + x_rows[i] = row; + } + + for (unsigned i = 0; i < outer_columns; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner); + LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, ""); + y_columns[i] = column; + } + + + + lbAddr res = lb_add_local_generated(p, type, true); + for_array(i, x_rows) { + LLVMValueRef x_row = x_rows[i]; + for_array(j, y_columns) { + LLVMValueRef y_column = y_columns[j]; + LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + LLVMBuildStore(p->builder, elem, dst.value); + } + } + return lb_addr_load(p, res); } { - Type *elem = xt->Matrix.elem; - lbAddr res = lb_add_local_generated(p, type, true); - i64 outer_rows = xt->Matrix.row_count; - i64 inner = xt->Matrix.column_count; - i64 outer_columns = yt->Matrix.column_count; - auto inners = slice_make(permanent_allocator(), inner); - for (i64 j = 0; j < outer_columns; j++) { - for (i64 i = 0; i < outer_rows; i++) { + for (unsigned j = 0; j < outer_columns; j++) { + for (unsigned i = 0; i < outer_rows; i++) { lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); - for (i64 k = 0; k < inner; k++) { + for (unsigned k = 0; k < inner; k++) { inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k); inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j); } - lbValue sum = lb_emit_load(p, dst); - for (i64 k = 0; k < inner; k++) { + lbValue sum = lb_const_nil(p->module, elem); + for (unsigned k = 0; k < inner; k++) { lbValue a = inners[k][0]; lbValue b = inners[k][1]; sum = lb_emit_mul_add(p, a, b, sum, elem); @@ -617,7 +672,6 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt))); Type *elem = mt->Matrix.elem; - LLVMTypeRef elem_type = lb_type(p->module, elem); if (lb_matrix_elem_simple(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); @@ -627,13 +681,7 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type auto m_columns = slice_make(permanent_allocator(), column_count); auto v_rows = slice_make(permanent_allocator(), column_count); - unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); - LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); - - LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value; - LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), ""); - LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); - + LLVMValueRef matrix_vector = lb_matrix_to_vector(p, lhs); for (unsigned column_index = 0; column_index < column_count; column_index++) { LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count); @@ -650,23 +698,12 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type GB_ASSERT(column_count > 0); LLVMValueRef vector = nullptr; - if (is_type_float(elem)) { - for (i64 i = 0; i < column_count; i++) { - LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], ""); - if (i == 0) { - vector = product; - } else { - vector = LLVMBuildFAdd(p->builder, vector, product, ""); - } - } - } else { - for (i64 i = 0; i < column_count; i++) { - LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], ""); - if (i == 0) { - vector = product; - } else { - vector = LLVMBuildAdd(p->builder, vector, product, ""); - } + for (i64 i = 0; i < column_count; i++) { + LLVMValueRef product = llvm_vector_mul(p, m_columns[i], v_rows[i]); + if (i == 0) { + vector = product; + } else { + vector = llvm_vector_add(p, vector, product); } } @@ -712,7 +749,6 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt))); Type *elem = mt->Matrix.elem; - LLVMTypeRef elem_type = lb_type(p->module, elem); if (lb_matrix_elem_simple(mt)) { unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); @@ -722,13 +758,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type auto m_columns = slice_make(permanent_allocator(), row_count); auto v_rows = slice_make(permanent_allocator(), row_count); - unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); - LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); - - LLVMValueRef matrix_ptr = lb_address_from_load_or_generate_local(p, rhs).value; - LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, matrix_ptr, LLVMPointerType(total_matrix_type, 0), ""); - LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); - + LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs); + for (unsigned row_index = 0; row_index < row_count; row_index++) { auto mask_elems = slice_make(temporary_allocator(), column_count); for (unsigned column_index = 0; column_index < column_count; column_index++) { @@ -751,23 +782,12 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type GB_ASSERT(row_count > 0); LLVMValueRef vector = nullptr; - if (is_type_float(elem)) { - for (i64 i = 0; i < row_count; i++) { - LLVMValueRef product = LLVMBuildFMul(p->builder, v_rows[i], m_columns[i], ""); - if (i == 0) { - vector = product; - } else { - vector = LLVMBuildFAdd(p->builder, vector, product, ""); - } - } - } else { - for (i64 i = 0; i < row_count; i++) { - LLVMValueRef product = LLVMBuildMul(p->builder, v_rows[i], m_columns[i], ""); - if (i == 0) { - vector = product; - } else { - vector = LLVMBuildAdd(p->builder, vector, product, ""); - } + for (i64 i = 0; i < row_count; i++) { + LLVMValueRef product = llvm_vector_mul(p, v_rows[i], m_columns[i]); + if (i == 0) { + vector = product; + } else { + vector = llvm_vector_add(p, vector, product); } } diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index b07dc3459..6754ce798 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1577,7 +1577,7 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { GB_ASSERT_MSG(id != 0, "Unable to find %s", name); LLVMTypeRef types[1] = {}; - types[0] = elem; + types[0] = type; LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); LLVMValueRef values[2] = {}; @@ -1585,4 +1585,31 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { values[1] = value; LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, ""); return call; +} + +LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); + + LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a)); + + if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) { + return LLVMBuildAdd(p->builder, a, b, ""); + } + return LLVMBuildFAdd(p->builder, a, b, ""); +} + +LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); + + LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a)); + + if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) { + return LLVMBuildMul(p->builder, a, b, ""); + } + return LLVMBuildFMul(p->builder, a, b, ""); +} + + +LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b)); } \ No newline at end of file