Add intrinsics.simd_sums_of_n

This commit is contained in:
gingerBill
2026-04-07 13:18:03 +01:00
parent 30b6fab120
commit 885db93e20
6 changed files with 178 additions and 5 deletions

View File

@@ -1691,7 +1691,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_extract_lsbs:
case BuiltinProc_simd_extract_msbs:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 elem_bits = 8*type_size_of(elem);
@@ -1719,7 +1719,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_shuffle:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 indices_count = ce->args.count-2;
@@ -1740,7 +1740,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_odd_even:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
u64 indices_count = cast(u64)vt->SimdVector.count;
@@ -1778,7 +1778,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
LLVMValueRef src = arg0.value;
LLVMValueRef indices = lb_build_expr(p, ce->args[1]).value;
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 count = vt->SimdVector.count;
Type *elem_type = vt->SimdVector.elem;
@@ -2042,6 +2042,107 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
return res;
}
case BuiltinProc_simd_sums_of_n:
{
TEMPORARY_ALLOCATOR_GUARD();
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
bool is_float = is_type_float(vt->SimdVector.elem);
LLVMTypeRef llvm_elem = lb_type(m, elem);
LLVMValueRef val = arg0.value;
LLVMTypeRef llvm_u32 = lb_type(m, t_u32);
u64 max_count = cast(u64)vt->SimdVector.count;
GB_ASSERT(ce->args[1]->tav.mode == Addressing_Constant);
u64 n = exact_value_to_u64(ce->args[1]->tav.value);
GB_ASSERT(max_count >= n);
GB_ASSERT(max_count % n == 0);
u64 new_size = max_count / n;
if (max_count == n) {
LLVMValueRef args[2] = {};
isize args_count = 0;
char const *name = nullptr;
if (is_float) {
name = "llvm.vector.reduce.fadd";
args[args_count++] = LLVMConstReal(llvm_elem, 0.0);
} else {
name = "llvm.vector.reduce.add";
}
args[args_count++] = arg0.value;
LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)};
res.value = lb_call_intrinsic(p, name, args, cast(unsigned)args_count, types, gb_count_of(types));
return res;
} else if (n == 2) {
LLVMValueRef *left_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
LLVMValueRef *right_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
for (u64 i = 0; i < new_size; i++) {
left_vals[i] = LLVMConstInt(llvm_u32, 2*i, false);
right_vals[i] = LLVMConstInt(llvm_u32, 2*i+1, false);
}
LLVMValueRef left_indices = LLVMConstVector(left_vals, cast(unsigned)new_size);
LLVMValueRef right_indices = LLVMConstVector(right_vals, cast(unsigned)new_size);
LLVMValueRef left = LLVMBuildShuffleVector(p->builder, val, val, left_indices, "");
LLVMValueRef right = LLVMBuildShuffleVector(p->builder, val, val, right_indices, "");
if (is_float) {
res.value = LLVMBuildFAdd(p->builder, left, right, "");
} else {
res.value = LLVMBuildAdd(p->builder, left, right, "");
}
} else {
LLVMValueRef *shuffled = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
LLVMValueRef *reductions = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
for (u64 i = 0; i < new_size; i++) {
u64 offset = i*n;
LLVMValueRef *index_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, n);
for (u64 j = 0; j < n; j++) {
index_vals[j] = LLVMConstInt(llvm_u32, offset+j, false);
}
LLVMValueRef indices = LLVMConstVector(index_vals, cast(unsigned)n);
shuffled[i] = LLVMBuildShuffleVector(p->builder, val, val, indices, "");
}
for (u64 i = 0; i < new_size; i++) {
LLVMValueRef args[2] = {};
isize args_count = 0;
char const *name = nullptr;
if (is_float) {
name = "llvm.vector.reduce.fadd";
args[args_count++] = LLVMConstReal(llvm_elem, 0.0);
} else {
name = "llvm.vector.reduce.add";
}
args[args_count++] = shuffled[i];
LLVMTypeRef this_simd_type = LLVMVectorType(llvm_elem, cast(unsigned)n);
LLVMTypeRef types[1] = {this_simd_type};
reductions[i] = lb_call_intrinsic(p, name, args, cast(unsigned)args_count, types, gb_count_of(types));
}
res.value = LLVMConstNull(LLVMVectorType(llvm_elem, cast(unsigned)new_size));
for (u64 i = 0; i < new_size; i++) {
LLVMValueRef idx = LLVMConstInt(llvm_u32, i, false);
res.value = LLVMBuildInsertElement(p->builder, res.value, reductions[i], idx, "");
}
}
return res;
} break;
case BuiltinProc_simd_ceil:
case BuiltinProc_simd_floor:
case BuiltinProc_simd_trunc: