From d3abc1a2b4fe024fed5f2b9f5371fc2b7fb029be Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 15:33:23 +0100 Subject: [PATCH] Add `matrix_flatten` - `matrix[R, C]T` -> `[R*C]T` --- src/check_builtin.cpp | 30 ++++++++++++++ src/checker_builtin_procs.hpp | 2 + src/llvm_backend_expr.cpp | 77 +++++++++++++++++++++++++++++++---- src/llvm_backend_proc.cpp | 6 +++ 4 files changed, 106 insertions(+), 9 deletions(-) diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp index a9427d4e0..b60509c03 100644 --- a/src/check_builtin.cpp +++ b/src/check_builtin.cpp @@ -2131,6 +2131,36 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 break; } + case BuiltinProc_matrix_flatten: { + Operand x = {}; + check_expr(c, &x, ce->args[0]); + if (x.mode == Addressing_Invalid) { + return false; + } + if (!is_operand_value(x)) { + error(call, "'%.*s' expects a matrix or array", LIT(builtin_name)); + return false; + } + Type *t = base_type(x.type); + if (!is_type_matrix(t) && !is_type_array(t)) { + gbString s = type_to_string(x.type); + error(call, "'%.*s' expects a matrix or array, got %s", LIT(builtin_name), s); + gb_string_free(s); + return false; + } + + operand->mode = Addressing_Value; + if (is_type_array(t)) { + // Do nothing + operand->type = x.type; + } else { + GB_ASSERT(t->kind == Type_Matrix); + operand->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count*t->Matrix.column_count); + } + operand->type = check_matrix_type_hint(operand->type, type_hint); + break; + } + case BuiltinProc_simd_vector: { Operand x = {}; diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index de4e99d14..5594c1a1a 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -38,6 +38,7 @@ enum BuiltinProcId { BuiltinProc_transpose, BuiltinProc_outer_product, BuiltinProc_hadamard_product, + BuiltinProc_matrix_flatten, BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures @@ -282,6 +283,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT("outer_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin}, + {STR_LIT("matrix_flatten"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index c1bdceba6..7d1c8e3db 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -517,6 +517,33 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { return matrix_vector; } +LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) { + Type *mt = base_type(m.type); + GB_ASSERT(mt->kind == Type_Matrix); + + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + + auto columns = slice_make(permanent_allocator(), column_count); + + LLVMValueRef vector = lb_matrix_to_vector(p, m); + + unsigned mask_elems_index = 0; + auto mask_elems = slice_make(permanent_allocator(), row_count*column_count); + for (unsigned j = 0; j < column_count; j++) { + for (unsigned i = 0; i < row_count; i++) { + unsigned offset = stride*j + i; + mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value; + } + } + + LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count); + LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, ""); + return trimmed_vector; +} + + lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { if (is_type_array(m.type)) { // no-op @@ -573,6 +600,46 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { return lb_addr_load(p, res); } +lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) { + lbAddr res = lb_add_local_generated(p, type, true); + LLVMValueRef res_ptr = res.addr.value; + unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); + LLVMSetAlignment(res_ptr, alignment); + + res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); + LLVMBuildStore(p->builder, vector, res_ptr); + + return lb_addr_load(p, res); +} + +lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) { + if (is_type_array(m.type)) { + // no-op + m.type = type; + return m; + } + Type *mt = base_type(m.type); + GB_ASSERT(mt->kind == Type_Matrix); + + if (lb_matrix_elem_simple(mt)) { + LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m); + return lb_matrix_cast_vector_to_type(p, vector, type); + } + + lbAddr res = lb_add_local_generated(p, type, true); + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue src = lb_emit_matrix_ev(p, m, i, j); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); +} + lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) { Type *mt = base_type(type); @@ -737,16 +804,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type vector = llvm_vector_add(p, vector, product); } } - - lbAddr res = lb_add_local_generated(p, type, true); - LLVMValueRef res_ptr = res.addr.value; - unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); - LLVMSetAlignment(res_ptr, alignment); - res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); - LLVMBuildStore(p->builder, vector, res_ptr); - - return lb_addr_load(p, res); + return lb_matrix_cast_vector_to_type(p, vector, type); } lbAddr res = lb_add_local_generated(p, type, true); diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index da4e4ad28..8686b3262 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -1280,6 +1280,12 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, GB_ASSERT(is_type_matrix(tv.type)); return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true); } + + case BuiltinProc_matrix_flatten: + { + lbValue m = lb_build_expr(p, ce->args[0]); + return lb_emit_matrix_flatten(p, m, tv.type); + } // "Intrinsics"