diff --git a/core/runtime/core.odin b/core/runtime/core.odin index 5f016579f..cb526ed2d 100644 --- a/core/runtime/core.odin +++ b/core/runtime/core.odin @@ -121,6 +121,9 @@ Type_Info_Union :: struct { variants: []^Type_Info, tag_offset: uintptr, tag_type: ^Type_Info, + + equal: Equal_Proc, // set only when the struct has .Comparable set but does not have .Simple_Compare set + custom_align: bool, no_nil: bool, maybe: bool, diff --git a/src/llvm_backend.cpp b/src/llvm_backend.cpp index de202658a..29a181f94 100644 --- a/src/llvm_backend.cpp +++ b/src/llvm_backend.cpp @@ -10259,7 +10259,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) { lb_start_block(p, block_diff_ptr); - if (type->kind == Type_Struct) { + if (type->kind == Type_Struct) { type_set_offsets(type); lbBlock *block_false = lb_create_block(p, "bfalse"); @@ -10285,6 +10285,56 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) { lb_start_block(p, block_false); LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false)); + } else if (type->kind == Type_Union) { + if (is_type_union_maybe_pointer(type)) { + Type *v = type->Union.variants[0]; + Type *pv = alloc_type_pointer(v); + + lbValue left = lb_emit_load(p, lb_emit_conv(p, lhs, pv)); + lbValue right = lb_emit_load(p, lb_emit_conv(p, rhs, pv)); + + lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right); + ok = lb_emit_conv(p, ok, t_bool); + LLVMBuildRet(p->builder, ok.value); + } else { + lbBlock *block_false = lb_create_block(p, "bfalse"); + lbBlock *block_switch = lb_create_block(p, "bswitch"); + + lbValue left_tag = lb_emit_load(p, lb_emit_union_tag_ptr(p, lhs)); + lbValue right_tag = lb_emit_load(p, lb_emit_union_tag_ptr(p, rhs)); + + lbValue tag_eq = lb_emit_comp(p, Token_CmpEq, left_tag, right_tag); + lb_emit_if(p, tag_eq, block_switch, block_false); + + lb_start_block(p, block_switch); + LLVMValueRef v_switch = LLVMBuildSwitch(p->builder, left_tag.value, block_false->block, cast(unsigned)type->Union.variants.count); + + + for_array(i, type->Union.variants) { + lbBlock *case_block = lb_create_block(p, "bcase"); + lb_start_block(p, case_block); + + Type *v = type->Union.variants[i]; + lbValue tag = lb_const_union_tag(p->module, type, v); + + Type *vp = alloc_type_pointer(v); + + lbValue left = lb_emit_load(p, lb_emit_conv(p, lhs, vp)); + lbValue right = lb_emit_load(p, lb_emit_conv(p, rhs, vp)); + lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right); + ok = lb_emit_conv(p, ok, t_bool); + + LLVMBuildRet(p->builder, ok.value); + + + LLVMAddCase(v_switch, tag.value, case_block->block); + } + + lb_start_block(p, block_false); + + LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false)); + } + } else { lbValue left = lb_emit_load(p, lhs); lbValue right = lb_emit_load(p, rhs); @@ -10565,7 +10615,7 @@ lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue ri } - if (is_type_struct(a) && is_type_comparable(a)) { + if ((is_type_struct(a) || is_type_union(b)) && is_type_comparable(a)) { lbValue left_ptr = lb_address_from_load_or_generate_local(p, left); lbValue right_ptr = lb_address_from_load_or_generate_local(p, right); lbValue res = {}; @@ -13467,7 +13517,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da tag = lb_const_ptr_cast(m, variant_ptr, t_type_info_union_ptr); { - LLVMValueRef vals[6] = {}; + LLVMValueRef vals[7] = {}; isize variant_count = gb_max(0, t->Union.variants.count); lbValue memory_types = lb_type_info_member_types_offset(p, variant_count); @@ -13496,10 +13546,19 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da vals[2] = LLVMConstNull(lb_type(m, t_type_info_ptr)); } - vals[3] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value; - vals[4] = lb_const_bool(m, t_bool, t->Union.no_nil).value; - vals[5] = lb_const_bool(m, t_bool, t->Union.maybe).value; + if (is_type_comparable(t) && !is_type_simple_compare(t)) { + vals[3] = lb_get_equal_proc_for_type(m, t).value; + } + vals[4] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value; + vals[5] = lb_const_bool(m, t_bool, t->Union.no_nil).value; + vals[6] = lb_const_bool(m, t_bool, t->Union.maybe).value; + + for (isize i = 0; i < gb_count_of(vals); i++) { + if (vals[i] == nullptr) { + vals[i] = LLVMConstNull(lb_type(m, get_struct_field_type(tag.type, i))); + } + } lbValue res = {}; res.type = type_deref(tag.type); diff --git a/src/types.cpp b/src/types.cpp index 7b5b81c43..5e4370f8a 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1544,6 +1544,9 @@ bool is_type_valid_for_keys(Type *t) { if (is_type_untyped(t)) { return false; } + if (t->kind == Type_Union) { + return false; + } return is_type_comparable(t); } @@ -1915,6 +1918,18 @@ bool is_type_comparable(Type *t) { } } return true; + + case Type_Union: + if (type_size_of(t) == 0) { + return false; + } + for_array(i, t->Union.variants) { + Type *v = t->Union.variants[i]; + if (!is_type_comparable(v)) { + return false; + } + } + return true; } return false; } @@ -1959,7 +1974,8 @@ bool is_type_simple_compare(Type *t) { return false; } } - return true; + // make it dumb on purpose + return t->Union.variants.count == 1; case Type_SimdVector: return is_type_simple_compare(t->SimdVector.elem);