From 7faca7066c30d6e663b268dc1e8ec66710ae3dd5 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 01:51:16 +0100 Subject: [PATCH] Add builtin `transpose` --- src/check_builtin.cpp | 34 ++++++++- src/check_expr.cpp | 36 +++++---- src/checker_builtin_procs.hpp | 4 + src/llvm_backend_expr.cpp | 135 +++++----------------------------- src/llvm_backend_proc.cpp | 6 ++ 5 files changed, 81 insertions(+), 134 deletions(-) diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp index a04302d01..659a74ad7 100644 --- a/src/check_builtin.cpp +++ b/src/check_builtin.cpp @@ -1966,13 +1966,13 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 return false; } if (!is_operand_value(x)) { - error(call, "'soa_unzip' expects an #soa slice"); + error(call, "'%.*s' expects an #soa slice", LIT(builtin_name)); return false; } Type *t = base_type(x.type); if (!is_type_soa_struct(t) || t->Struct.soa_kind != StructSoa_Slice) { gbString s = type_to_string(x.type); - error(call, "'soa_unzip' expects an #soa slice, got %s", s); + error(call, "'%.*s' expects an #soa slice, got %s", LIT(builtin_name), s); gb_string_free(s); return false; } @@ -1987,6 +1987,36 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 operand->mode = Addressing_Value; break; } + + case BuiltinProc_transpose: { + Operand x = {}; + check_expr(c, &x, ce->args[0]); + if (x.mode == Addressing_Invalid) { + return false; + } + if (!is_operand_value(x)) { + error(call, "'%.*s' expects a matrix or array", LIT(builtin_name)); + return false; + } + Type *t = base_type(x.type); + if (!is_type_matrix(t) && !is_type_array(t)) { + gbString s = type_to_string(x.type); + error(call, "'%.*s' expects a matrix or array, got %s", LIT(builtin_name), s); + gb_string_free(s); + return false; + } + + operand->mode = Addressing_Value; + if (is_type_array(t)) { + // Do nothing + operand->type = x.type; + } else { + GB_ASSERT(t->kind == Type_Matrix); + operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count); + } + operand->type = check_matrix_type_hint(operand->type, type_hint); + break; + } case BuiltinProc_simd_vector: { Operand x = {}; diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 299810ce0..8a1e5fd86 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -2708,6 +2708,25 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type return false; } +Type *check_matrix_type_hint(Type *matrix, Type *type_hint) { + Type *xt = base_type(matrix); + if (type_hint != nullptr) { + Type *th = base_type(type_hint); + if (are_types_identical(th, xt)) { + return type_hint; + } else if (xt->kind == Type_Matrix && th->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { + // ignore + } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { + return type_hint; + } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { + return type_hint; + } + } + } + return matrix; +} + void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) { if (!check_binary_op(c, x, op)) { @@ -2791,21 +2810,8 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand } matrix_success: - if (type_hint != nullptr) { - Type *th = base_type(type_hint); - if (are_types_identical(th, x->type)) { - x->type = type_hint; - } else if (x->type->kind == Type_Matrix && th->kind == Type_Array) { - Type *xt = x->type; - if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { - // ignore - } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { - x->type = type_hint; - } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { - x->type = type_hint; - } - } - } + x->type = check_matrix_type_hint(x->type, type_hint); + return; diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index 8991d2d5c..21a33bdd3 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -34,6 +34,8 @@ enum BuiltinProcId { BuiltinProc_soa_zip, BuiltinProc_soa_unzip, + + BuiltinProc_transpose, BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures @@ -274,6 +276,8 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("soa_zip"), 1, true, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT("soa_unzip"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, + + {STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 518ce33af..d41a0a127 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -502,116 +502,29 @@ bool lb_matrix_elem_simple(Type *t) { return true; } -LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) { - lbModule *m = p->module; - - Type *mt = base_type(lhs.type); +lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { + if (is_type_array(m.type)) { + m.type = type; + return m; + } + Type *mt = base_type(m.type); GB_ASSERT(mt->kind == Type_Matrix); - GB_ASSERT(lb_matrix_elem_simple(mt)); + lbAddr res = lb_add_local_generated(p, type, true); - i64 stride = matrix_type_stride_in_elems(mt); - i64 rows = mt->Matrix.row_count; - i64 columns = mt->Matrix.column_count; - unsigned elem_count = cast(unsigned)(rows*columns); - - Type *elem = mt->Matrix.elem; - LLVMTypeRef elem_type = lb_type(m, elem); - - LLVMTypeRef vector_type = LLVMVectorType(elem_type, elem_count); - LLVMTypeRef types[] = {vector_type}; - - char const *name = "llvm.matrix.column.major.load"; - unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); - GB_ASSERT_MSG(id != 0, "Unable to find %s", name); - LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); - - lbValue ptr = lb_address_from_load_or_generate_local(p, lhs); - ptr = lb_emit_matrix_epi(p, ptr, 0, 0); - - LLVMValueRef values[5] = {}; - values[0] = ptr.value; - values[1] = lb_const_int(m, t_u64, stride).value; - values[2] = LLVMConstNull(lb_type(m, t_llvm_bool)); - values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; - values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; - - LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); - gb_printf_err("%s\n", LLVMPrintValueToString(call)); - // LLVMAddAttributeAtIndex(call, 0, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt))); - return call; + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue src = lb_emit_matrix_ev(p, m, i, j); + lbValue dst = lb_emit_matrix_epi(p, res.addr, j, i); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); + } -void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) { - lbModule *m = p->module; - - Type *mt = base_type(lb_addr_type(addr)); - GB_ASSERT(mt->kind == Type_Matrix); - GB_ASSERT(lb_matrix_elem_simple(mt)); - - LLVMTypeRef vector_type = LLVMTypeOf(vector_value); - LLVMTypeRef types[] = {vector_type}; - - char const *name = "llvm.matrix.column.major.store"; - unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); - GB_ASSERT_MSG(id != 0, "Unable to find %s", name); - LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); - - lbValue ptr = lb_addr_get_ptr(p, addr); - ptr = lb_emit_matrix_epi(p, ptr, 0, 0); - - unsigned vector_size = LLVMGetVectorSize(vector_type); - GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size); - - i64 stride = matrix_type_stride_in_elems(mt); - - LLVMValueRef values[6] = {}; - values[0] = vector_value; - values[1] = ptr.value; - values[2] = lb_const_int(m, t_u64, stride).value; - values[3] = LLVMConstNull(lb_type(m, t_llvm_bool)); - values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; - values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; - - LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); - gb_printf_err("%s\n", LLVMPrintValueToString(call)); - // LLVMAddAttributeAtIndex(call, 1, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt))); - gb_unused(call); -} - - -LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) { - lbModule *m = p->module; - - LLVMTypeRef a_type = LLVMTypeOf(a); - LLVMTypeRef b_type = LLVMTypeOf(b); - - GB_ASSERT(LLVMGetElementType(a_type) == LLVMGetElementType(b_type)); - - LLVMTypeRef elem_type = LLVMGetElementType(a_type); - - LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns)); - - LLVMTypeRef types[] = {res_vector_type, a_type, b_type}; - - char const *name = "llvm.matrix.multiply"; - unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); - GB_ASSERT_MSG(id != 0, "Unable to find %s", name); - LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); - - LLVMValueRef values[5] = {}; - values[0] = a; - values[1] = b; - values[2] = lb_const_int(m, t_u32, outer_rows).value; - values[3] = lb_const_int(m, t_u32, inner).value; - values[4] = lb_const_int(m, t_u32, outer_columns).value; - - LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); - gb_printf_err("%s\n", LLVMPrintValueToString(call)); - return call; -} - - lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); @@ -626,18 +539,6 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) goto slow_form; } - if (false) { - // TODO(bill): LLVM ERROR: Do not know how to split the result of this operator! - lbAddr res = lb_add_local_generated(p, type, true); - - LLVMValueRef a = llvm_matrix_column_major_load(p, lhs); gb_unused(a); - LLVMValueRef b = llvm_matrix_column_major_load(p, rhs); gb_unused(b); - LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); gb_unused(c); - llvm_matrix_column_major_store(p, res, c); - - return lb_addr_load(p, res); - } - slow_form: { Type *elem = xt->Matrix.elem; diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index 222161164..1431fffaa 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -1257,6 +1257,12 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, return lb_soa_zip(p, ce, tv); case BuiltinProc_soa_unzip: return lb_soa_unzip(p, ce, tv); + + case BuiltinProc_transpose: + { + lbValue m = lb_build_expr(p, ce->args[0]); + return lb_emit_matrix_tranpose(p, m, tv.type); + } // "Intrinsics"