From 618d3bf62fbcfa6ca7f827ad4090143b8535b4a2 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Sat, 28 May 2022 13:42:58 +0100 Subject: [PATCH] Improve vector comparison `==` `!=` for horizontal reduction --- src/llvm_backend_expr.cpp | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 55b76b93a..1894e85f6 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -2585,16 +2585,35 @@ lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue ri } GB_ASSERT_MSG(mask != nullptr, "Unhandled comparison kind %s (%s) %.*s %s (%s)", type_to_string(left.type), type_to_string(base_type(left.type)), LIT(token_strings[op_kind]), type_to_string(right.type), type_to_string(base_type(right.type))); - // TODO(bill): is this a good approach to dealing with comparisons of vectors? - char const *name = "llvm.vector.reduce.umax"; - LLVMTypeRef types[1] = {LLVMTypeOf(mask)}; - unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); - GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); - LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); + /* NOTE(bill, 2022-05-28): + Thanks to Per Vognsen, sign extending to + a vector of the same width as the input vector, bit casting to an integer, + and then comparing against zero is the better option + See: https://lists.llvm.org/pipermail/llvm-dev/2012-September/053046.html - LLVMValueRef args[1] = {}; - args[0] = mask; - res.value = LLVMBuildCall(p->builder, ip, args, gb_count_of(args), ""); + // Example assuming 128-bit vector + + %1 = <4 x float> ... + %2 = <4 x float> ... + %3 = fcmp oeq <4 x float> %1, %2 + %4 = sext <4 x i1> %3 to <4 x i32> + %5 = bitcast <4 x i32> %4 to i128 + %6 = icmp ne i128 %5, 0 + br i1 %6, label %true1, label %false2 + + This will result in 1 cmpps + 1 ptest + 1 br + (even without SSE4.1, contrary to what the mail list states, because of pmovmskb) + + */ + + unsigned count = cast(unsigned)get_array_type_count(a); + unsigned elem_sz = cast(unsigned)(type_size_of(elem)*8); + LLVMTypeRef mask_type = LLVMVectorType(LLVMIntTypeInContext(p->module->ctx, elem_sz), count); + mask = LLVMBuildSExtOrBitCast(p->builder, mask, mask_type, ""); + + LLVMTypeRef mask_int_type = LLVMIntTypeInContext(p->module->ctx, cast(unsigned)(8*type_size_of(a))); + LLVMValueRef mask_int = LLVMBuildBitCast(p->builder, mask, mask_int_type, ""); + res.value = LLVMBuildICmp(p->builder, LLVMIntNE, mask_int, LLVMConstNull(LLVMTypeOf(mask_int)), ""); return res; } else {