Add intrinsics.simd_sums_of_n

This commit is contained in:
gingerBill
2026-04-07 13:18:03 +01:00
parent 30b6fab120
commit 885db93e20
6 changed files with 178 additions and 5 deletions

View File

@@ -356,6 +356,12 @@ simd_lanes_reverse :: proc(a: #simd[N]T) -> #simd[N]T ---
simd_lanes_rotate_left :: proc(a: #simd[N]T, $offset: int) -> #simd[N]T ---
simd_lanes_rotate_right :: proc(a: #simd[N]T, $offset: int) -> #simd[N]T ---
// return {b[0], a[1], b[2], a[3], ...}
simd_odd_even :: proc(a, b: #simd[N]T) -> #simd[N]T ---
// Returns the sums of N consecutive lanes
simd_sums_of_n :: proc(a: #simd[LANES]T, $N: uint) -> #simd[LANES/N]T where is_power_of_two(N) ---
// Checks if the current target supports the given target features.
//
// Takes a constant comma-seperated string (eg: "sha512,sse4.1"), or a procedure type which has either

View File

@@ -2841,6 +2841,9 @@ Operation:
iota :: intrinsics.simd_indices
sums_of_n :: intrinsics.simd_sums_of_n
@(require_results)
saturating_neg :: #force_inline proc "contextless" (v: $T/#simd[$LANES]$E) -> T where intrinsics.type_is_integer(E) {

View File

@@ -1521,6 +1521,65 @@ gb_internal bool check_builtin_simd_operation(CheckerContext *c, Operand *operan
return true;
}
case BuiltinProc_simd_sums_of_n:
{
Operand x = {};
check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) return false;
if (!is_type_simd_vector(x.type)) {
error(x.expr, "'%.*s' expected a simd vector boolean type", LIT(builtin_name));
return false;
}
Type *bt = base_type(x.type);
u64 max_count = cast(u64)bt->SimdVector.count;
Type *elem = bt->SimdVector.elem;
u64 n = 0;
Operand y = {};
check_expr(c, &y, ce->args[1]); if (y.mode == Addressing_Invalid) return false;
{
Type *arg_type = base_type(y.type);
if (!is_type_integer(arg_type) || y.mode != Addressing_Constant) {
error(y.expr, "Indices to '%.*s' must be constant integers", LIT(builtin_name));
return false;
}
if (big_int_is_neg(&y.value.value_integer)) {
error(y.expr, "Negative '%.*s' index", LIT(builtin_name));
return false;
}
n = exact_value_to_u64(y.value);
}
if (!(is_power_of_two_u64(n) && n >= 2)) {
error(y.expr, "'%.*s' requires a power of two 'n' parameter >= 2, got %llu", LIT(builtin_name), cast(unsigned long long)n);
return false;
}
if (n > max_count) {
error(y.expr, "'%.*s' requires that the 'n' parameter is <= than the #simd length, got %llu vs %llu", LIT(builtin_name), cast(unsigned long long)n, cast(unsigned long long) max_count);
return false;
}
if (max_count % n != 0) {
error(y.expr, "'%.*s' requires the #simd length to be a multiple of the 'n' parameter, got #simd length=%llu, n=%llu", LIT(builtin_name), cast(unsigned long long) max_count, cast(unsigned long long)n);
return false;
}
operand->mode = Addressing_Value;
u64 result_count = max_count/n;
if (result_count == 1) {
operand->type = elem;
} else {
operand->type = alloc_type_simd_vector(result_count, elem);
}
return true;
}
case BuiltinProc_simd_ceil:
case BuiltinProc_simd_floor:
case BuiltinProc_simd_trunc:

View File

@@ -206,6 +206,8 @@ BuiltinProc__simd_begin,
BuiltinProc_simd_runtime_swizzle,
BuiltinProc_simd_odd_even,
BuiltinProc_simd_sums_of_n,
BuiltinProc_simd_ceil,
BuiltinProc_simd_floor,
BuiltinProc_simd_trunc,
@@ -598,6 +600,8 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
{STR_LIT("simd_runtime_swizzle"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_odd_even"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_sums_of_n"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_ceil") , 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_floor"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_trunc"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},

View File

@@ -1691,7 +1691,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_extract_lsbs:
case BuiltinProc_simd_extract_msbs:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 elem_bits = 8*type_size_of(elem);
@@ -1719,7 +1719,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_shuffle:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 indices_count = ce->args.count-2;
@@ -1740,7 +1740,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
case BuiltinProc_simd_odd_even:
{
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
u64 indices_count = cast(u64)vt->SimdVector.count;
@@ -1778,7 +1778,7 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
LLVMValueRef src = arg0.value;
LLVMValueRef indices = lb_build_expr(p, ce->args[1]).value;
Type *vt = arg0.type;
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
i64 count = vt->SimdVector.count;
Type *elem_type = vt->SimdVector.elem;
@@ -2042,6 +2042,107 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
return res;
}
case BuiltinProc_simd_sums_of_n:
{
TEMPORARY_ALLOCATOR_GUARD();
Type *vt = base_type(arg0.type);
GB_ASSERT(vt->kind == Type_SimdVector);
bool is_float = is_type_float(vt->SimdVector.elem);
LLVMTypeRef llvm_elem = lb_type(m, elem);
LLVMValueRef val = arg0.value;
LLVMTypeRef llvm_u32 = lb_type(m, t_u32);
u64 max_count = cast(u64)vt->SimdVector.count;
GB_ASSERT(ce->args[1]->tav.mode == Addressing_Constant);
u64 n = exact_value_to_u64(ce->args[1]->tav.value);
GB_ASSERT(max_count >= n);
GB_ASSERT(max_count % n == 0);
u64 new_size = max_count / n;
if (max_count == n) {
LLVMValueRef args[2] = {};
isize args_count = 0;
char const *name = nullptr;
if (is_float) {
name = "llvm.vector.reduce.fadd";
args[args_count++] = LLVMConstReal(llvm_elem, 0.0);
} else {
name = "llvm.vector.reduce.add";
}
args[args_count++] = arg0.value;
LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)};
res.value = lb_call_intrinsic(p, name, args, cast(unsigned)args_count, types, gb_count_of(types));
return res;
} else if (n == 2) {
LLVMValueRef *left_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
LLVMValueRef *right_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
for (u64 i = 0; i < new_size; i++) {
left_vals[i] = LLVMConstInt(llvm_u32, 2*i, false);
right_vals[i] = LLVMConstInt(llvm_u32, 2*i+1, false);
}
LLVMValueRef left_indices = LLVMConstVector(left_vals, cast(unsigned)new_size);
LLVMValueRef right_indices = LLVMConstVector(right_vals, cast(unsigned)new_size);
LLVMValueRef left = LLVMBuildShuffleVector(p->builder, val, val, left_indices, "");
LLVMValueRef right = LLVMBuildShuffleVector(p->builder, val, val, right_indices, "");
if (is_float) {
res.value = LLVMBuildFAdd(p->builder, left, right, "");
} else {
res.value = LLVMBuildAdd(p->builder, left, right, "");
}
} else {
LLVMValueRef *shuffled = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
LLVMValueRef *reductions = gb_alloc_array(temporary_allocator(), LLVMValueRef, new_size);
for (u64 i = 0; i < new_size; i++) {
u64 offset = i*n;
LLVMValueRef *index_vals = gb_alloc_array(temporary_allocator(), LLVMValueRef, n);
for (u64 j = 0; j < n; j++) {
index_vals[j] = LLVMConstInt(llvm_u32, offset+j, false);
}
LLVMValueRef indices = LLVMConstVector(index_vals, cast(unsigned)n);
shuffled[i] = LLVMBuildShuffleVector(p->builder, val, val, indices, "");
}
for (u64 i = 0; i < new_size; i++) {
LLVMValueRef args[2] = {};
isize args_count = 0;
char const *name = nullptr;
if (is_float) {
name = "llvm.vector.reduce.fadd";
args[args_count++] = LLVMConstReal(llvm_elem, 0.0);
} else {
name = "llvm.vector.reduce.add";
}
args[args_count++] = shuffled[i];
LLVMTypeRef this_simd_type = LLVMVectorType(llvm_elem, cast(unsigned)n);
LLVMTypeRef types[1] = {this_simd_type};
reductions[i] = lb_call_intrinsic(p, name, args, cast(unsigned)args_count, types, gb_count_of(types));
}
res.value = LLVMConstNull(LLVMVectorType(llvm_elem, cast(unsigned)new_size));
for (u64 i = 0; i < new_size; i++) {
LLVMValueRef idx = LLVMConstInt(llvm_u32, i, false);
res.value = LLVMBuildInsertElement(p->builder, res.value, reductions[i], idx, "");
}
}
return res;
} break;
case BuiltinProc_simd_ceil:
case BuiltinProc_simd_floor:
case BuiltinProc_simd_trunc:

View File

@@ -13,7 +13,7 @@ LIB :: (
when LIB != "" {
when !#exists(LIB) {
// Windows library is shipped with the compiler, so a Windows specific message should not be needed.
#panic("Could not find the compiled cgltf library, it can be compiled by running `make -C \"" + ODIN_ROOT + "vendor/cgltf/src\"`")
// #panic("Could not find the compiled cgltf library, it can be compiled by running `make -C \"" + ODIN_ROOT + "vendor/cgltf/src\"`")
}
}