From 3dcc22fa6d0779e35e193ba4f5fae6b919d89080 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 16 Apr 2025 10:52:35 +0100 Subject: [PATCH] Change hashing rules for float-like types to make `0 == -0` --- base/runtime/dynamic_map_internal.odin | 29 ++++++++++++++++ core/reflect/reflect.odin | 43 +++++++++++++++++++++-- src/check_type.cpp | 15 ++++++++ src/llvm_backend.cpp | 47 ++++++++++++++++++++++++++ src/types.cpp | 2 +- 5 files changed, 132 insertions(+), 4 deletions(-) diff --git a/base/runtime/dynamic_map_internal.odin b/base/runtime/dynamic_map_internal.odin index 96ae9c73c..7b65a2fa0 100644 --- a/base/runtime/dynamic_map_internal.odin +++ b/base/runtime/dynamic_map_internal.odin @@ -1029,3 +1029,32 @@ default_hasher_cstring :: proc "contextless" (data: rawptr, seed: uintptr) -> ui h &= HASH_MASK return uintptr(h) | uintptr(uintptr(h) == 0) } + +default_hasher_f64 :: proc "contextless" (f: f64, seed: uintptr) -> uintptr { + f := f + buf: [size_of(f)]u8 + if f == 0 { + return default_hasher(&buf, seed, size_of(buf)) + } + if f != f { + // TODO(bill): What should the logic be for NaNs? + return default_hasher(&f, seed, size_of(f)) + } + return default_hasher(&f, seed, size_of(f)) +} + +default_hasher_complex128 :: proc "contextless" (x, y: f64, seed: uintptr) -> uintptr { + seed := seed + seed = default_hasher_f64(x, seed) + seed = default_hasher_f64(y, seed) + return seed +} + +default_hasher_quaternion256 :: proc "contextless" (x, y, z, w: f64, seed: uintptr) -> uintptr { + seed := seed + seed = default_hasher_f64(x, seed) + seed = default_hasher_f64(y, seed) + seed = default_hasher_f64(z, seed) + seed = default_hasher_f64(w, seed) + return seed +} \ No newline at end of file diff --git a/core/reflect/reflect.odin b/core/reflect/reflect.odin index 115b19b64..b3315a0c3 100644 --- a/core/reflect/reflect.odin +++ b/core/reflect/reflect.odin @@ -1439,6 +1439,11 @@ as_f64 :: proc(a: any) -> (value: f64, valid: bool) { case Type_Info_Complex: switch v in a { + case complex32: + if imag(v) == 0 { + value = f64(real(v)) + valid = true + } case complex64: if imag(v) == 0 { value = f64(real(v)) @@ -1453,6 +1458,11 @@ as_f64 :: proc(a: any) -> (value: f64, valid: bool) { case Type_Info_Quaternion: switch v in a { + case quaternion64: + if imag(v) == 0 && jmag(v) == 0 && kmag(v) == 0 { + value = f64(real(v)) + valid = true + } case quaternion128: if imag(v) == 0 && jmag(v) == 0 && kmag(v) == 0 { value = f64(real(v)) @@ -1646,13 +1656,40 @@ equal :: proc(a, b: any, including_indirect_array_recursion := false, recursion_ return equal(va, vb, including_indirect_array_recursion, recursion_level+1) case Type_Info_Map: return false + case Type_Info_Float: + x, _ := as_f64(a) + y, _ := as_f64(b) + return x == y + case Type_Info_Complex: + switch x in a { + case complex32: + #no_type_assert y := b.(complex32) + return x == y + case complex64: + #no_type_assert y := b.(complex64) + return x == y + case complex128: + #no_type_assert y := b.(complex128) + return x == y + } + return false + case Type_Info_Quaternion: + switch x in a { + case quaternion64: + #no_type_assert y := b.(quaternion64) + return x == y + case quaternion128: + #no_type_assert y := b.(quaternion128) + return x == y + case quaternion256: + #no_type_assert y := b.(quaternion256) + return x == y + } + return false case Type_Info_Boolean, Type_Info_Integer, Type_Info_Rune, - Type_Info_Float, - Type_Info_Complex, - Type_Info_Quaternion, Type_Info_Type_Id, Type_Info_Pointer, Type_Info_Multi_Pointer, diff --git a/src/check_type.cpp b/src/check_type.cpp index 89dcacfc5..1549f477e 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -2774,6 +2774,21 @@ gb_internal void add_map_key_type_dependencies(CheckerContext *ctx, Type *key) { return; } + if (key->kind == Type_Basic) { + if (key->Basic.flags & BasicFlag_Quaternion) { + add_package_dependency(ctx, "runtime", "default_hasher_f64"); + add_package_dependency(ctx, "runtime", "default_hasher_quaternion256"); + return; + } else if (key->Basic.flags & BasicFlag_Complex) { + add_package_dependency(ctx, "runtime", "default_hasher_f64"); + add_package_dependency(ctx, "runtime", "default_hasher_complex128"); + return; + } else if (key->Basic.flags & BasicFlag_Float) { + add_package_dependency(ctx, "runtime", "default_hasher_f64"); + return; + } + } + if (key->kind == Type_Struct) { add_package_dependency(ctx, "runtime", "default_hasher"); for_array(i, key->Struct.fields) { diff --git a/src/llvm_backend.cpp b/src/llvm_backend.cpp index ee0ea7567..083a1d90e 100644 --- a/src/llvm_backend.cpp +++ b/src/llvm_backend.cpp @@ -563,6 +563,53 @@ gb_internal lbValue lb_hasher_proc_for_type(lbModule *m, Type *type) { lbValue res = lb_emit_runtime_call(p, "default_hasher_string", args); lb_add_callsite_force_inline(p, res); LLVMBuildRet(p->builder, res.value); + } else if (is_type_float(type)) { + lbValue ptr = lb_emit_conv(p, data, pt); + lbValue v = lb_emit_load(p, ptr); + v = lb_emit_conv(p, v, t_f64); + + auto args = array_make(temporary_allocator(), 2); + args[0] = v; + args[1] = seed; + lbValue res = lb_emit_runtime_call(p, "default_hasher_f64", args); + lb_add_callsite_force_inline(p, res); + LLVMBuildRet(p->builder, res.value); + } else if (is_type_complex(type)) { + lbValue ptr = lb_emit_conv(p, data, pt); + lbValue xp = lb_emit_struct_ep(p, ptr, 0); + lbValue yp = lb_emit_struct_ep(p, ptr, 1); + + lbValue x = lb_emit_conv(p, lb_emit_load(p, xp), t_f64); + lbValue y = lb_emit_conv(p, lb_emit_load(p, yp), t_f64); + + auto args = array_make(temporary_allocator(), 3); + args[0] = x; + args[1] = y; + args[2] = seed; + lbValue res = lb_emit_runtime_call(p, "default_hasher_complex128", args); + lb_add_callsite_force_inline(p, res); + LLVMBuildRet(p->builder, res.value); + } else if (is_type_quaternion(type)) { + lbValue ptr = lb_emit_conv(p, data, pt); + lbValue xp = lb_emit_struct_ep(p, ptr, 0); + lbValue yp = lb_emit_struct_ep(p, ptr, 1); + lbValue zp = lb_emit_struct_ep(p, ptr, 2); + lbValue wp = lb_emit_struct_ep(p, ptr, 3); + + lbValue x = lb_emit_conv(p, lb_emit_load(p, xp), t_f64); + lbValue y = lb_emit_conv(p, lb_emit_load(p, yp), t_f64); + lbValue z = lb_emit_conv(p, lb_emit_load(p, zp), t_f64); + lbValue w = lb_emit_conv(p, lb_emit_load(p, wp), t_f64); + + auto args = array_make(temporary_allocator(), 5); + args[0] = x; + args[1] = y; + args[2] = z; + args[3] = w; + args[4] = seed; + lbValue res = lb_emit_runtime_call(p, "default_hasher_quaternion256", args); + lb_add_callsite_force_inline(p, res); + LLVMBuildRet(p->builder, res.value); } else { GB_PANIC("Unhandled type for hasher: %s", type_to_string(type)); } diff --git a/src/types.cpp b/src/types.cpp index 48631a373..9c9472a28 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -111,7 +111,7 @@ enum BasicFlag { BasicFlag_Ordered = BasicFlag_Integer | BasicFlag_Float | BasicFlag_String | BasicFlag_Pointer | BasicFlag_Rune, BasicFlag_OrderedNumeric = BasicFlag_Integer | BasicFlag_Float | BasicFlag_Rune, BasicFlag_ConstantType = BasicFlag_Boolean | BasicFlag_Numeric | BasicFlag_String | BasicFlag_Pointer | BasicFlag_Rune, - BasicFlag_SimpleCompare = BasicFlag_Boolean | BasicFlag_Numeric | BasicFlag_Pointer | BasicFlag_Rune, + BasicFlag_SimpleCompare = BasicFlag_Boolean | BasicFlag_Integer | BasicFlag_Pointer | BasicFlag_Rune, }; struct BasicType {