Merge pull request #5442 from jon-lipstate/table_lookup

table lookup simd intrinsic
This commit is contained in:
gingerBill
2025-07-22 11:14:54 +01:00
committed by GitHub
5 changed files with 375 additions and 0 deletions

View File

@@ -314,6 +314,7 @@ simd_indices :: proc($T: typeid/#simd[$N]$E) -> T where type_is_numeric(T) ---
simd_shuffle :: proc(a, b: #simd[N]T, indices: ..int) -> #simd[len(indices)]T ---
simd_select :: proc(cond: #simd[N]boolean_or_integer, true, false: #simd[N]T) -> #simd[N]T ---
simd_runtime_swizzle :: proc(table: #simd[N]T, indices: #simd[N]T) -> #simd[N]T where type_is_integer(T) ---
// Lane-wise operations
simd_ceil :: proc(a: #simd[N]any_float) -> #simd[N]any_float ---

View File

@@ -2440,6 +2440,57 @@ Graphically, the operation looks as follows. The `t` and `f` represent the
*/
select :: intrinsics.simd_select
/*
Runtime Equivalent to Shuffle.
Performs element-wise table lookups using runtime indices.
Each element in the indices vector selects an element from the table vector.
The indices are automatically masked to prevent out-of-bounds access.
This operation is hardware-accelerated on most platforms when using 8-bit
integer vectors. For other element types or unsupported vector sizes, it
falls back to software emulation.
Inputs:
- `table`: The lookup table vector (should be power-of-2 size for correct masking).
- `indices`: The indices vector (automatically masked to valid range).
Returns:
- A vector where `result[i] = table[indices[i] & (table_size-1)]`.
Operation:
for i in 0 ..< len(indices) {
masked_index := indices[i] & (len(table) - 1)
result[i] = table[masked_index]
}
return result
Implementation:
| Platform | Lane Size | Implementation |
|-------------|-------------------------------------------|---------------------|
| x86-64 | pshufb (16B), vpshufb (32B), AVX512 (64B) | Single vector |
| ARM64 | tbl1 (16B), tbl2 (32B), tbl4 (64B) | Automatic splitting |
| ARM32 | vtbl1 (8B), vtbl2 (16B), vtbl4 (32B) | Automatic splitting |
| WebAssembly | i8x16.swizzle (16B), Emulation (>16B) | Mixed |
| Other | Emulation | Software |
Example:
import "core:simd"
import "core:fmt"
runtime_swizzle_example :: proc() {
table := simd.u8x16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
indices := simd.u8x16{15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
result := simd.runtime_swizzle(table, indices)
fmt.println(result) // Expected: {15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
}
*/
runtime_swizzle :: intrinsics.simd_runtime_swizzle
/*
Compute the square root of each lane in a SIMD vector.
*/

View File

@@ -1159,6 +1159,58 @@ gb_internal bool check_builtin_simd_operation(CheckerContext *c, Operand *operan
return true;
}
case BuiltinProc_simd_runtime_swizzle:
{
if (ce->args.count != 2) {
error(call, "'%.*s' expected 2 arguments, got %td", LIT(builtin_name), ce->args.count);
return false;
}
Operand src = {};
Operand indices = {};
check_expr(c, &src, ce->args[0]); if (src.mode == Addressing_Invalid) return false;
check_expr_with_type_hint(c, &indices, ce->args[1], src.type); if (indices.mode == Addressing_Invalid) return false;
if (!is_type_simd_vector(src.type)) {
error(src.expr, "'%.*s' expected first argument to be a simd vector", LIT(builtin_name));
return false;
}
if (!is_type_simd_vector(indices.type)) {
error(indices.expr, "'%.*s' expected second argument (indices) to be a simd vector", LIT(builtin_name));
return false;
}
Type *src_elem = base_array_type(src.type);
Type *indices_elem = base_array_type(indices.type);
if (!is_type_integer(src_elem)) {
gbString src_str = type_to_string(src.type);
error(src.expr, "'%.*s' expected first argument to be a simd vector of integers, got '%s'", LIT(builtin_name), src_str);
gb_string_free(src_str);
return false;
}
if (!is_type_integer(indices_elem)) {
gbString indices_str = type_to_string(indices.type);
error(indices.expr, "'%.*s' expected indices to be a simd vector of integers, got '%s'", LIT(builtin_name), indices_str);
gb_string_free(indices_str);
return false;
}
if (!are_types_identical(src.type, indices.type)) {
gbString src_str = type_to_string(src.type);
gbString indices_str = type_to_string(indices.type);
error(indices.expr, "'%.*s' expected both arguments to have the same type, got '%s' vs '%s'", LIT(builtin_name), src_str, indices_str);
gb_string_free(indices_str);
gb_string_free(src_str);
return false;
}
operand->mode = Addressing_Value;
operand->type = src.type;
return true;
}
case BuiltinProc_simd_ceil:
case BuiltinProc_simd_floor:
case BuiltinProc_simd_trunc:

View File

@@ -191,6 +191,7 @@ BuiltinProc__simd_begin,
BuiltinProc_simd_shuffle,
BuiltinProc_simd_select,
BuiltinProc_simd_runtime_swizzle,
BuiltinProc_simd_ceil,
BuiltinProc_simd_floor,
@@ -552,6 +553,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
{STR_LIT("simd_shuffle"), 2, true, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_select"), 3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_runtime_swizzle"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_ceil") , 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
{STR_LIT("simd_floor"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},

View File

@@ -1721,6 +1721,275 @@ gb_internal lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAn
return res;
}
case BuiltinProc_simd_runtime_swizzle:
{
LLVMValueRef src = arg0.value;
LLVMValueRef indices = lb_build_expr(p, ce->args[1]).value;
Type *vt = arg0.type;
GB_ASSERT(vt->kind == Type_SimdVector);
i64 count = vt->SimdVector.count;
Type *elem_type = vt->SimdVector.elem;
i64 elem_size = type_size_of(elem_type);
// Determine strategy based on element size and target architecture
char const *intrinsic_name = nullptr;
bool use_hardware_runtime_swizzle = false;
// 8-bit elements: Use dedicated table lookup instructions
if (elem_size == 1) {
use_hardware_runtime_swizzle = true;
if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
// x86/x86-64: Use pshufb intrinsics
switch (count) {
case 16:
intrinsic_name = "llvm.x86.ssse3.pshuf.b.128";
break;
case 32:
intrinsic_name = "llvm.x86.avx2.pshuf.b";
break;
case 64:
intrinsic_name = "llvm.x86.avx512.pshuf.b.512";
break;
default:
use_hardware_runtime_swizzle = false;
break;
}
} else if (build_context.metrics.arch == TargetArch_arm64) {
// ARM64: Use NEON tbl intrinsics with automatic table splitting
switch (count) {
case 16:
intrinsic_name = "llvm.aarch64.neon.tbl1";
break;
case 32:
intrinsic_name = "llvm.aarch64.neon.tbl2";
break;
case 48:
intrinsic_name = "llvm.aarch64.neon.tbl3";
break;
case 64:
intrinsic_name = "llvm.aarch64.neon.tbl4";
break;
default:
use_hardware_runtime_swizzle = false;
break;
}
} else if (build_context.metrics.arch == TargetArch_arm32) {
// ARM32: Use NEON vtbl intrinsics with automatic table splitting
switch (count) {
case 8:
intrinsic_name = "llvm.arm.neon.vtbl1";
break;
case 16:
intrinsic_name = "llvm.arm.neon.vtbl2";
break;
case 24:
intrinsic_name = "llvm.arm.neon.vtbl3";
break;
case 32:
intrinsic_name = "llvm.arm.neon.vtbl4";
break;
default:
use_hardware_runtime_swizzle = false;
break;
}
} else if (build_context.metrics.arch == TargetArch_wasm32 || build_context.metrics.arch == TargetArch_wasm64p32) {
// WebAssembly: Use swizzle (only supports 16-byte vectors)
if (count == 16) {
intrinsic_name = "llvm.wasm.swizzle";
} else {
use_hardware_runtime_swizzle = false;
}
} else {
use_hardware_runtime_swizzle = false;
}
}
if (use_hardware_runtime_swizzle && intrinsic_name != nullptr) {
// Use dedicated hardware swizzle instruction
// Check if required target features are enabled
bool features_enabled = true;
if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
// x86/x86-64 feature checking
if (count == 16) {
// SSE/SSSE3 for 128-bit vectors
if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr)) {
features_enabled = false;
}
} else if (count == 32) {
// AVX2 requires ssse3 + avx2 features
if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr) ||
!check_target_feature_is_enabled(str_lit("avx2"), nullptr)) {
features_enabled = false;
}
} else if (count == 64) {
// AVX512 requires ssse3 + avx2 + avx512f + avx512bw features
if (!check_target_feature_is_enabled(str_lit("ssse3"), nullptr) ||
!check_target_feature_is_enabled(str_lit("avx2"), nullptr) ||
!check_target_feature_is_enabled(str_lit("avx512f"), nullptr) ||
!check_target_feature_is_enabled(str_lit("avx512bw"), nullptr)) {
features_enabled = false;
}
}
} else if (build_context.metrics.arch == TargetArch_arm64 || build_context.metrics.arch == TargetArch_arm32) {
// ARM/ARM64 feature checking - NEON is required for all table/swizzle ops
if (!check_target_feature_is_enabled(str_lit("neon"), nullptr)) {
features_enabled = false;
}
}
if (features_enabled) {
// Add target features to function attributes for LLVM instruction selection
if (build_context.metrics.arch == TargetArch_amd64 || build_context.metrics.arch == TargetArch_i386) {
// x86/x86-64 function attributes
if (count == 16) {
// SSE/SSSE3 for 128-bit vectors
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+ssse3"));
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("128"));
} else if (count == 32) {
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+avx,+avx2,+ssse3"));
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("256"));
} else if (count == 64) {
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+avx,+avx2,+avx512f,+avx512bw,+ssse3"));
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("512"));
}
} else if (build_context.metrics.arch == TargetArch_arm64) {
// ARM64 function attributes - enable NEON for swizzle instructions
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+neon"));
// Set appropriate vector width for multi-swizzle operations
if (count >= 32) {
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("min-legal-vector-width"), str_lit("256"));
}
} else if (build_context.metrics.arch == TargetArch_arm32) {
// ARM32 function attributes - enable NEON for swizzle instructions
lb_add_attribute_to_proc_with_string(p->module, p->value, str_lit("target-features"), str_lit("+neon"));
}
// Handle ARM's multi-swizzle intrinsics by splitting the src vector
if (build_context.metrics.arch == TargetArch_arm64 && count > 16) {
// ARM64 TBL2/TBL3/TBL4: Split src into multiple 16-byte vectors
int num_tables = cast(int)(count / 16);
GB_ASSERT_MSG(count % 16 == 0, "ARM64 src size must be multiple of 16 bytes, got %lld bytes", count);
GB_ASSERT_MSG(num_tables <= 4, "ARM64 NEON supports maximum 4 tables (tbl4), got %d tables for %lld-byte vector", num_tables, count);
LLVMValueRef src_parts[4]; // Max 4 tables for tbl4
for (int i = 0; i < num_tables; i++) {
// Extract 16-byte slice from the larger src
LLVMValueRef indices_for_extract[16];
for (int j = 0; j < 16; j++) {
indices_for_extract[j] = LLVMConstInt(LLVMInt32TypeInContext(p->module->ctx), i * 16 + j, false);
}
LLVMValueRef extract_mask = LLVMConstVector(indices_for_extract, 16);
src_parts[i] = LLVMBuildShuffleVector(p->builder, src, LLVMGetUndef(LLVMTypeOf(src)), extract_mask, "");
}
// Call appropriate ARM64 tbl intrinsic
if (count == 32) {
LLVMValueRef args[3] = { src_parts[0], src_parts[1], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 3, nullptr, 0);
} else if (count == 48) {
LLVMValueRef args[4] = { src_parts[0], src_parts[1], src_parts[2], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 4, nullptr, 0);
} else if (count == 64) {
LLVMValueRef args[5] = { src_parts[0], src_parts[1], src_parts[2], src_parts[3], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 5, nullptr, 0);
}
} else if (build_context.metrics.arch == TargetArch_arm32 && count > 8) {
// ARM32 VTBL2/VTBL3/VTBL4: Split src into multiple 8-byte vectors
int num_tables = cast(int)count / 8;
GB_ASSERT_MSG(count % 8 == 0, "ARM32 src size must be multiple of 8 bytes, got %lld bytes", count);
GB_ASSERT_MSG(num_tables <= 4, "ARM32 NEON supports maximum 4 tables (vtbl4), got %d tables for %lld-byte vector", num_tables, count);
LLVMValueRef src_parts[4]; // Max 4 tables for vtbl4
for (int i = 0; i < num_tables; i++) {
// Extract 8-byte slice from the larger src
LLVMValueRef indices_for_extract[8];
for (int j = 0; j < 8; j++) {
indices_for_extract[j] = LLVMConstInt(LLVMInt32TypeInContext(p->module->ctx), i * 8 + j, false);
}
LLVMValueRef extract_mask = LLVMConstVector(indices_for_extract, 8);
src_parts[i] = LLVMBuildShuffleVector(p->builder, src, LLVMGetUndef(LLVMTypeOf(src)), extract_mask, "");
}
// Call appropriate ARM32 vtbl intrinsic
if (count == 16) {
LLVMValueRef args[3] = { src_parts[0], src_parts[1], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 3, nullptr, 0);
} else if (count == 24) {
LLVMValueRef args[4] = { src_parts[0], src_parts[1], src_parts[2], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 4, nullptr, 0);
} else if (count == 32) {
LLVMValueRef args[5] = { src_parts[0], src_parts[1], src_parts[2], src_parts[3], indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, 5, nullptr, 0);
}
} else {
// Single runtime swizzle case (x86, WebAssembly, ARM single-table)
LLVMValueRef args[2] = { src, indices };
res.value = lb_call_intrinsic(p, intrinsic_name, args, gb_count_of(args), nullptr, 0);
}
return res;
} else {
// Features not enabled, fall back to emulation
use_hardware_runtime_swizzle = false;
}
}
// Fallback: Emulate with extracts and inserts for all element sizes
GB_ASSERT(count > 0 && count <= 64); // Sanity check
LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, count);
LLVMTypeRef i32_type = LLVMInt32TypeInContext(p->module->ctx);
LLVMTypeRef elem_llvm_type = lb_type(p->module, elem_type);
// Calculate mask based on element size and vector count
i64 max_index = count - 1;
LLVMValueRef index_mask;
if (elem_size == 1) {
// 8-bit: mask to src size (like pshufb behavior)
index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
} else if (elem_size == 2) {
// 16-bit: mask to src size
index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
} else if (elem_size == 4) {
// 32-bit: mask to src size
index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
} else {
// 64-bit: mask to src size
index_mask = LLVMConstInt(elem_llvm_type, max_index, false);
}
for (i64 i = 0; i < count; i++) {
LLVMValueRef idx_i = LLVMConstInt(i32_type, cast(unsigned)i, false);
LLVMValueRef index_elem = LLVMBuildExtractElement(p->builder, indices, idx_i, "");
// Mask index to valid range
LLVMValueRef masked_index = LLVMBuildAnd(p->builder, index_elem, index_mask, "");
// Convert to i32 for extractelement
LLVMValueRef index_i32;
if (LLVMGetIntTypeWidth(LLVMTypeOf(masked_index)) < 32) {
index_i32 = LLVMBuildZExt(p->builder, masked_index, i32_type, "");
} else if (LLVMGetIntTypeWidth(LLVMTypeOf(masked_index)) > 32) {
index_i32 = LLVMBuildTrunc(p->builder, masked_index, i32_type, "");
} else {
index_i32 = masked_index;
}
values[i] = LLVMBuildExtractElement(p->builder, src, index_i32, "");
}
// Build result vector
res.value = LLVMGetUndef(LLVMTypeOf(src));
for (i64 i = 0; i < count; i++) {
LLVMValueRef idx_i = LLVMConstInt(i32_type, cast(unsigned)i, false);
res.value = LLVMBuildInsertElement(p->builder, res.value, values[i], idx_i, "");
}
return res;
}
case BuiltinProc_simd_ceil:
case BuiltinProc_simd_floor:
case BuiltinProc_simd_trunc: