From 1bfbed0e02b4cd947acf9693f09016ec609356e1 Mon Sep 17 00:00:00 2001
From: gingerBill
Date: Wed, 20 Oct 2021 12:48:48 +0100
Subject: [PATCH] Add `llvm_vector_reduce_add`

---
 src/llvm_backend_expr.cpp    |  3 ++-
 src/llvm_backend_utility.cpp | 54 ++++++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index 6cb221a94..18d5e267b 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -619,9 +619,10 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 
 	Type *elem = mt->Matrix.elem;
 	LLVMTypeRef elem_type = lb_type(p->module, elem);
-	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 
 	if (lb_matrix_elem_simple(mt)) {
+		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
+
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
 		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 		auto m_columns = slice_make(permanent_allocator(), column_count);
diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp
index 56637e907..b07dc3459 100644
--- a/src/llvm_backend_utility.cpp
+++ b/src/llvm_backend_utility.cpp
@@ -1544,4 +1544,58 @@ LLVMValueRef llvm_splat(lbProcedure *p, LLVMValueRef value, unsigned count) {
 	}
 	LLVMValueRef mask = llvm_mask_zero(p->module, count);
 	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
+}
+
+// Emits a call to the LLVM vector-reduction intrinsic that sums the lanes of
+// `value` and returns the value of that call.
+//
+// `value` must have an LLVM vector type (asserted). Half/float/double lanes use
+// `llvm.vector.reduce.fadd`, called with (start_value, vector) where the start
+// value is the null constant of the element type. Integer lanes use
+// `llvm.vector.reduce.add`, called with the vector only. Any other element
+// kind panics.
+LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
+	LLVMTypeRef type = LLVMTypeOf(value);
+	GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind);
+	LLVMTypeRef elem = LLVMGetElementType(type);
+
+	char const *name = nullptr;
+	// `values` further down is laid out as {start_value, vector}; these pick the
+	// sub-range of it that is passed to the selected intrinsic.
+	i32 value_offset = 0;
+	i32 value_count = 0;
+
+	switch (LLVMGetTypeKind(elem)) {
+	case LLVMHalfTypeKind:
+	case LLVMFloatTypeKind:
+	case LLVMDoubleTypeKind:
+		name = "llvm.vector.reduce.fadd";
+		value_offset = 0;
+		value_count = 2;
+		break;
+	case LLVMIntegerTypeKind:
+		name = "llvm.vector.reduce.add";
+		value_offset = 1;
+		value_count = 1;
+		break;
+	default:
+		GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type));
+		break;
+	}
+
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+
+	LLVMTypeRef types[1] = {};
+	// The reduce intrinsics are overloaded on the vector operand type (e.g.
+	// `llvm.vector.reduce.add.v4i32`), so the vector type is passed here rather
+	// than the element type.
+	types[0] = type;
+
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+	LLVMValueRef values[2] = {};
+	values[0] = LLVMConstNull(elem); // start value; only passed for the fadd form
+	values[1] = value;
+	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
+	return call;
 }
\ No newline at end of file