diff --git a/src/check_type.cpp b/src/check_type.cpp index aef1ddc7a..f70230682 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -2109,6 +2109,12 @@ void add_map_key_type_dependencies(CheckerContext *ctx, Type *key) { Entity *field = key->Struct.fields[i]; add_map_key_type_dependencies(ctx, field->type); } + } else if (key->kind == Type_Union) { + add_package_dependency(ctx, "runtime", "default_hasher_n"); + for_array(i, key->Union.variants) { + Type *v = key->Union.variants[i]; + add_map_key_type_dependencies(ctx, v); + } } else if (key->kind == Type_EnumeratedArray) { add_package_dependency(ctx, "runtime", "default_hasher_n"); add_map_key_type_dependencies(ctx, key->EnumeratedArray.elem); diff --git a/src/llvm_backend.cpp b/src/llvm_backend.cpp index 29a181f94..e931d1ce9 100644 --- a/src/llvm_backend.cpp +++ b/src/llvm_backend.cpp @@ -10315,7 +10315,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) { lb_start_block(p, case_block); Type *v = type->Union.variants[i]; - lbValue tag = lb_const_union_tag(p->module, type, v); + lbValue case_tag = lb_const_union_tag(p->module, type, v); Type *vp = alloc_type_pointer(v); @@ -10327,7 +10327,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) { LLVMBuildRet(p->builder, ok.value); - LLVMAddCase(v_switch, tag.value, case_block->block); + LLVMAddCase(v_switch, case_tag.value, case_block->block); } lb_start_block(p, block_false); @@ -10403,6 +10403,9 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) { lbValue data = {x, t_rawptr}; lbValue seed = {y, t_uintptr}; + LLVMAttributeRef nonnull_attr = lb_create_enum_attribute(m->ctx, "nonnull"); + LLVMAddAttributeAtIndex(p->value, 1+0, nonnull_attr); + if (is_type_simple_compare(type)) { lbValue res = lb_simple_compare_hash(p, type, data, seed); LLVMBuildRet(p->builder, res.value); @@ -10425,6 +10428,38 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) { seed = lb_emit_call(p, field_hasher, args); } LLVMBuildRet(p->builder, seed.value); + } else if (type->kind == Type_Union) { + lbBlock *end_block = lb_create_block(p, "bend"); + + data = lb_emit_conv(p, data, pt); + + lbValue tag_ptr = lb_emit_union_tag_ptr(p, data); + lbValue tag = lb_emit_load(p, tag_ptr); + + LLVMValueRef v_switch = LLVMBuildSwitch(p->builder, tag.value, end_block->block, cast(unsigned)type->Union.variants.count); + + auto args = array_make(permanent_allocator(), 2); + for_array(i, type->Union.variants) { + lbBlock *case_block = lb_create_block(p, "bcase"); + lb_start_block(p, case_block); + + Type *v = type->Union.variants[i]; + Type *vp = alloc_type_pointer(v); + lbValue case_tag = lb_const_union_tag(p->module, type, v); + + lbValue variant_hasher = lb_get_hasher_proc_for_type(m, v); + + args[0] = data; + args[1] = seed; + lbValue res = lb_emit_call(p, variant_hasher, args); + LLVMBuildRet(p->builder, res.value); + + LLVMAddCase(v_switch, case_tag.value, case_block->block); + } + + lb_start_block(p, end_block); + LLVMBuildRet(p->builder, seed.value); + } else if (type->kind == Type_Array) { lbAddr pres = lb_add_local_generated(p, t_uintptr, false); lb_addr_store(p, pres, seed); diff --git a/src/types.cpp b/src/types.cpp index 5e4370f8a..3fafa00d5 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1544,9 +1544,6 @@ bool is_type_valid_for_keys(Type *t) { if (is_type_untyped(t)) { return false; } - if (t->kind == Type_Union) { - return false; - } return is_type_comparable(t); }