Add builtin outer_product

This commit is contained in:
gingerBill
2021-10-20 02:06:56 +01:00
parent 7faca7066c
commit 68afbb37f4
4 changed files with 102 additions and 0 deletions

View File

@@ -522,9 +522,41 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
}
}
return lb_addr_load(p, res);
}
lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
Type *mt = base_type(type);
Type *at = base_type(a.type);
Type *bt = base_type(b.type);
GB_ASSERT(mt->kind == Type_Matrix);
GB_ASSERT(at->kind == Type_Array);
GB_ASSERT(bt->kind == Type_Array);
i64 row_count = mt->Matrix.row_count;
i64 column_count = mt->Matrix.column_count;
GB_ASSERT(row_count == at->Array.count);
GB_ASSERT(column_count == bt->Array.count);
lbAddr res = lb_add_local_generated(p, type, true);
for (i64 j = 0; j < column_count; j++) {
for (i64 i = 0; i < row_count; i++) {
lbValue x = lb_emit_struct_ev(p, a, cast(i32)i);
lbValue y = lb_emit_struct_ev(p, b, cast(i32)j);
lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem);
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
lb_emit_store(p, dst, src);
}
}
return lb_addr_load(p, res);
}
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
Type *xt = base_type(lhs.type);
Type *yt = base_type(rhs.type);