From cee45c1b155fcc917c2b0f9cfdbfa060304255e1 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 02:18:30 +0100 Subject: [PATCH] Add `hadamard_product` --- src/check_builtin.cpp | 56 ++++++++++++++++++++++++++++++++++- src/check_type.cpp | 12 ++------ src/checker_builtin_procs.hpp | 2 ++ src/llvm_backend_expr.cpp | 6 ++-- src/llvm_backend_proc.cpp | 10 +++++++ src/types.cpp | 11 +++++++ 6 files changed, 84 insertions(+), 13 deletions(-) diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp index 1d033932f..a9427d4e0 100644 --- a/src/check_builtin.cpp +++ b/src/check_builtin.cpp @@ -2056,6 +2056,14 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 return false; } + Type *elem = xt->Array.elem; + + if (!is_type_valid_for_matrix_elems(elem)) { + gbString s = type_to_string(elem); + error(call, "Matrix elements types are limited to integers, floats, and complex, got %s", s); + gb_string_free(s); + } + if (xt->Array.count == 0 || yt->Array.count == 0) { gbString s1 = type_to_string(x.type); gbString s2 = type_to_string(y.type); @@ -2072,7 +2080,53 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 } operand->mode = Addressing_Value; - operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count); + operand->type = alloc_type_matrix(elem, xt->Array.count, yt->Array.count); + operand->type = check_matrix_type_hint(operand->type, type_hint); + break; + } + + case BuiltinProc_hadamard_product: { + Operand x = {}; + Operand y = {}; + check_expr(c, &x, ce->args[0]); + if (x.mode == Addressing_Invalid) { + return false; + } + check_expr(c, &y, ce->args[1]); + if (y.mode == Addressing_Invalid) { + return false; + } + if (!is_operand_value(x) || !is_operand_value(y)) { + error(call, "'%.*s' expects a matrix or array types", LIT(builtin_name)); + return false; + } + if (!is_type_matrix(x.type) && !is_type_array(y.type)) { + gbString s1 = type_to_string(x.type); + gbString s2 = type_to_string(y.type); + error(call, "'%.*s' expects matrix or array values, got %s and %s", LIT(builtin_name), s1, s2); + gb_string_free(s2); + gb_string_free(s1); + return false; + } + + if (!are_types_identical(x.type, y.type)) { + gbString s1 = type_to_string(x.type); + gbString s2 = type_to_string(y.type); + error(call, "'%.*s' values of the same type, got %s and %s", LIT(builtin_name), s1, s2); + gb_string_free(s2); + gb_string_free(s1); + return false; + } + + Type *elem = core_array_type(x.type); + if (!is_type_valid_for_matrix_elems(elem)) { + gbString s = type_to_string(elem); + error(call, "'%.*s' expects elements to be types are limited to integers, floats, and complex, got %s", LIT(builtin_name), s); + gb_string_free(s); + } + + operand->mode = Addressing_Value; + operand->type = x.type; operand->type = check_matrix_type_hint(operand->type, type_hint); break; } diff --git a/src/check_type.cpp b/src/check_type.cpp index e752f192d..d9302c65a 100644 --- a/src/check_type.cpp +++ b/src/check_type.cpp @@ -997,8 +997,8 @@ void check_bit_set_type(CheckerContext *c, Type *type, Type *named_type, Ast *no GB_ASSERT(lower <= upper); - i64 bits = MAX_BITS; - if (bs->underlying != nullptr) { + i64 bits = MAX_BITS +; if (bs->underlying != nullptr) { Type *u = check_type(c, bs->underlying); if (!is_type_integer(u)) { gbString ts = type_to_string(u); @@ -2239,13 +2239,7 @@ void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node) { error(column.expr, "Matrix types are limited to a maximum of %d elements, got %lld", MAX_MATRIX_ELEMENT_COUNT, cast(long long)element_count); } - if (is_type_integer(elem)) { - // okay - } else if (is_type_float(elem)) { - // okay - } else if (is_type_complex(elem)) { - // okay - } else { + if (!is_type_valid_for_matrix_elems(elem)) { gbString s = type_to_string(elem); error(column.expr, "Matrix elements types are limited to integers, floats, and complex, got %s", s); gb_string_free(s); diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index 2c7392b09..de4e99d14 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -37,6 +37,7 @@ enum BuiltinProcId { BuiltinProc_transpose, BuiltinProc_outer_product, + BuiltinProc_hadamard_product, BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures @@ -280,6 +281,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT("outer_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin}, + {STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 27f12a829..b894bc7b8 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -672,13 +672,13 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type -lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) { GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); - if (op == Token_Mul) { + if (op == Token_Mul && !component_wise) { if (xt->kind == Type_Matrix) { if (yt->kind == Type_Matrix) { return lb_emit_matrix_mul(p, lhs, rhs, type); @@ -703,7 +703,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue array_lhs.type = array_type; array_rhs.type = array_type; - lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type); + lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, array_type); array.type = type; return array; } diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index 5a7fc1626..da4e4ad28 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -1270,6 +1270,16 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, lbValue b = lb_build_expr(p, ce->args[1]); return lb_emit_outer_product(p, a, b, tv.type); } + case BuiltinProc_hadamard_product: + { + lbValue a = lb_build_expr(p, ce->args[0]); + lbValue b = lb_build_expr(p, ce->args[1]); + if (is_type_array(tv.type)) { + return lb_emit_arith(p, Token_Mul, a, b, tv.type); + } + GB_ASSERT(is_type_matrix(tv.type)); + return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true); + } // "Intrinsics" diff --git a/src/types.cpp b/src/types.cpp index eaf1bac74..32e26bcc6 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1333,6 +1333,17 @@ i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { return stride_elems*column_index + row_index; } +bool is_type_valid_for_matrix_elems(Type *t) { + if (is_type_integer(t)) { + return true; + } else if (is_type_float(t)) { + return true; + } else if (is_type_complex(t)) { + return true; + } + return false; +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray;