diff --git a/src/llvm_abi.cpp b/src/llvm_abi.cpp index d4dde2a3d..1c7ca9f78 100644 --- a/src/llvm_abi.cpp +++ b/src/llvm_abi.cpp @@ -30,6 +30,11 @@ gb_internal lbArgType lb_arg_type_direct(LLVMTypeRef type) { return lb_arg_type_direct(type, nullptr, nullptr, nullptr); } +gb_internal lbArgType lb_arg_type_direct_inreg(LLVMContextRef c, LLVMTypeRef type) { + LLVMAttributeRef inreg_attr = lb_create_enum_attribute_with_type(c, "inreg", type); + return lb_arg_type_direct(type, nullptr, nullptr, inreg_attr); +} + gb_internal lbArgType lb_arg_type_indirect(LLVMTypeRef type, LLVMAttributeRef attr) { return lbArgType{lbArg_Indirect, type, nullptr, nullptr, attr, nullptr, 0, false}; } @@ -144,7 +149,7 @@ gb_internal void lb_add_function_calling_convention(LLVMValueRef fn, ProcCalling case ProcCC_Odin: case ProcCC_Contextless: if (ALLOW_WIN64_VECTORCALL_ABI) { - cc_kind = lbCallingConvention_X86_VectorCall; + // cc_kind = lbCallingConvention_X86_VectorCall; } break; } @@ -479,9 +484,9 @@ namespace lbAbiAmd64Win64 { LLVMContextRef c = m->ctx; lbFunctionType *ft = permanent_alloc_item(); ft->ctx = c; + ft->calling_convention = calling_convention; ft->args = compute_arg_types(c, arg_types, arg_count, calling_convention); ft->ret = compute_return_type(ft, c, return_type, return_is_defined, return_is_tuple); - ft->calling_convention = calling_convention; return ft; } @@ -496,6 +501,114 @@ namespace lbAbiAmd64Win64 { return false; } + enum RegClass { + RegClass_Other, + RegClass_Memory, + RegClass_Int, + RegClass_Half, + RegClass_Float, + RegClass_Double, + }; + + gb_internal void vectorcall_classify_with(LLVMTypeRef t, Array *cls, i64 ix, i64 off) { + i64 t_align = lb_alignof(t); + + i64 misalign = off % t_align; + if (misalign != 0) { + array_add(cls, RegClass_Memory); + return; + } + + switch (LLVMGetTypeKind(t)) { + case LLVMIntegerTypeKind: { + array_add(cls, RegClass_Int); + break; + } + case LLVMPointerTypeKind: + array_add(cls, RegClass_Int); + break; + case LLVMHalfTypeKind: + array_add(cls, RegClass_Half); + break; + case LLVMFloatTypeKind: + array_add(cls, RegClass_Float); + break; + case LLVMDoubleTypeKind: + array_add(cls, RegClass_Double); + break; + case LLVMStructTypeKind: + { + LLVMBool packed = LLVMIsPackedStruct(t); + unsigned field_count = LLVMCountStructElementTypes(t); + + i64 field_off = off; + for (unsigned field_index = 0; field_index < field_count; field_index++) { + LLVMTypeRef field_type = LLVMStructGetTypeAtIndex(t, field_index); + if (!packed) { + field_off = llvm_align_formula(field_off, lb_alignof(field_type)); + } + vectorcall_classify_with(field_type, cls, ix, field_off); + field_off += lb_sizeof(field_type); + } + } + break; + case LLVMArrayTypeKind: + { + i64 len = LLVMGetArrayLength(t); + LLVMTypeRef elem = OdinLLVMGetArrayElementType(t); + i64 elem_sz = lb_sizeof(elem); + for (i64 i = 0; i < len; i++) { + vectorcall_classify_with(elem, cls, ix, off + i*elem_sz); + } + } + break; + case LLVMVectorTypeKind: + { + i64 len = LLVMGetVectorSize(t); + LLVMTypeRef elem = OdinLLVMGetVectorElementType(t); + i64 elem_sz = lb_sizeof(elem); + LLVMTypeKind elem_kind = LLVMGetTypeKind(elem); + switch (elem_kind) { + case LLVMIntegerTypeKind: { + unsigned elem_width = LLVMGetIntTypeWidth(elem); + if (elem_width > 64) { + for (i64 i = 0; i < len; i++) { + vectorcall_classify_with(elem, cls, ix, off + i*elem_sz); + } + break; + } else { + array_add(cls, RegClass_Int); + } + break; + }; + case LLVMHalfTypeKind: + array_add(cls, RegClass_Half); + break; + case LLVMFloatTypeKind: + array_add(cls, RegClass_Float); + break; + case LLVMDoubleTypeKind: + array_add(cls, RegClass_Double); + break; + default: + GB_PANIC("Unhandled vector element type"); + } + } + break; + default: + GB_PANIC("Unhandled type"); + break; + } + } + + gb_internal Array vectorcall_classify(LLVMTypeRef t) { + i64 sz = lb_sizeof(t); + i64 words = (sz + 7)/8; + auto reg_classes = array_make(heap_allocator(), 0, cast(isize)words); + vectorcall_classify_with(t, ®_classes, 0, 0); + return reg_classes; + } + gb_internal Array compute_arg_types(LLVMContextRef c, LLVMTypeRef *arg_types, unsigned arg_count, ProcCallingConvention calling_convention) { auto args = array_make(lb_function_type_args_allocator(), arg_count); @@ -504,16 +617,34 @@ namespace lbAbiAmd64Win64 { LLVMTypeKind kind = LLVMGetTypeKind(t); if (is_vectorcall(calling_convention)) { - if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind) { - #if 0 + if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind || kind == LLVMVectorTypeKind) { i64 sz = lb_sizeof(t); - if (sz <= 8) { - args[i] = lb_arg_type_direct(t, LLVMIntTypeInContext(c, 8*cast(unsigned)sz), nullptr, nullptr); - } else { - args[i] = lb_arg_type_indirect(t, nullptr); + auto cls = vectorcall_classify(t); + defer (array_free(&cls)); + + if (sz <= 32 && + cls.count > 0 && + cls.data[0] >= RegClass_Float) { + bool is_inreg = true; + auto first = cls.data[0]; + for (isize i = 1; i < cls.count; i++) { + if (cls.data[i] != first) { + is_inreg = false; + break; + } + } + if (is_inreg) { + if (first == RegClass_Float && sz <= 16) { + args[i] = lb_arg_type_direct_inreg(c, t); + continue; + } else if (first == RegClass_Double && sz <= 32) { + args[i] = lb_arg_type_direct_inreg(c, t); + continue; + } + + } } - #else - i64 sz = lb_sizeof(t); + switch (sz) { case 1: case 2: @@ -525,21 +656,12 @@ namespace lbAbiAmd64Win64 { args[i] = lb_arg_type_indirect(t, nullptr); break; } - #endif } - if (kind == LLVMVectorTypeKind) { - i64 sz = lb_sizeof(t); - if (sz <= 32) { - args[i] = lb_arg_type_direct(t, t, nullptr, nullptr); - } else { - args[i] = lbAbi386::non_struct(c, t, false); - } - } else { - args[i] = lbAbi386::non_struct(c, t, false); - } + + args[i] = lbAbi386::non_struct(c, t, false); } else { - if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind) { + if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind || kind == LLVMVectorTypeKind) { i64 sz = lb_sizeof(t); switch (sz) { case 1: @@ -563,13 +685,42 @@ namespace lbAbiAmd64Win64 { gb_internal LB_ABI_COMPUTE_RETURN_TYPE(compute_return_type) { if (!return_is_defined) { return lb_arg_type_direct(LLVMVoidTypeInContext(c)); - } else if (lb_is_type_kind(return_type, LLVMStructTypeKind) || lb_is_type_kind(return_type, LLVMArrayTypeKind)) { - i64 sz = lb_sizeof(return_type); + } + LLVMTypeKind kind = LLVMGetTypeKind(return_type); + + i64 sz = lb_sizeof(return_type); + + if (is_vectorcall(ft->calling_convention)) { + if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind || kind == LLVMVectorTypeKind) { + auto cls = vectorcall_classify(return_type); + defer (array_free(&cls)); + + if (sz <= 32 && + cls.count > 0 && + cls.data[0] >= RegClass_Float) { + bool is_inreg = true; + auto first = cls.data[0]; + for (isize i = 1; i < cls.count; i++) { + if (cls.data[i] != first) { + is_inreg = false; + break; + } + } + if (is_inreg) { + if (first == RegClass_Float && sz <= 16) { + return lb_arg_type_direct(return_type); + } else if (first == RegClass_Double && sz <= 32) { + return lb_arg_type_direct(return_type); + } + } + } + } + } + + if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind || kind == LLVMVectorTypeKind) { switch (sz) { - case 1: return lb_arg_type_direct(return_type, LLVMIntTypeInContext(c, 8), nullptr, nullptr); - case 2: return lb_arg_type_direct(return_type, LLVMIntTypeInContext(c, 16), nullptr, nullptr); - case 4: return lb_arg_type_direct(return_type, LLVMIntTypeInContext(c, 32), nullptr, nullptr); - case 8: return lb_arg_type_direct(return_type, LLVMIntTypeInContext(c, 64), nullptr, nullptr); + case 1: case 2: case 4: case 8: + return lb_arg_type_direct(return_type, LLVMIntTypeInContext(c, 8*cast(unsigned)sz), nullptr, nullptr); } LB_ABI_MODIFY_RETURN_IF_TUPLE_MACRO(); @@ -578,17 +729,6 @@ namespace lbAbiAmd64Win64 { return lb_arg_type_indirect(return_type, attr); } - if (is_vectorcall(ft->calling_convention)) { - if (lb_is_type_kind(return_type, LLVMVectorTypeKind)) { - i64 sz = lb_sizeof(return_type); - if (sz <= 32) { - return lb_arg_type_direct(return_type, return_type, nullptr, nullptr); - } - - return lb_arg_type_indirect(return_type, nullptr); - } - } - return lbAbi386::non_struct(c, return_type, true); } };