mirror of
https://github.com/odin-lang/Odin.git
synced 2026-06-03 09:14:38 +00:00
Very basic matrix support in backend
This commit is contained in:
@@ -1954,7 +1954,40 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) {
|
||||
}
|
||||
|
||||
case runtime.Type_Info_Matrix:
|
||||
io.write_string(fi.writer, "[]")
|
||||
reflect.write_type(fi.writer, type_info_of(v.id))
|
||||
io.write_byte(fi.writer, '{')
|
||||
defer io.write_byte(fi.writer, '}')
|
||||
|
||||
fi.indent += 1; defer fi.indent -= 1
|
||||
|
||||
if fi.hash {
|
||||
io.write_byte(fi.writer, '\n')
|
||||
// TODO(bill): Should this render it like in written form? e.g. tranposed
|
||||
for col in 0..<info.column_count {
|
||||
fmt_write_indent(fi)
|
||||
for row in 0..<info.row_count {
|
||||
if row > 0 { io.write_string(fi.writer, ", ") }
|
||||
|
||||
offset := row*info.elem_size + col*info.stride
|
||||
|
||||
data := uintptr(v.data) + uintptr(offset)
|
||||
fmt_arg(fi, any{rawptr(data), info.elem.id}, verb)
|
||||
}
|
||||
io.write_string(fi.writer, ";\n")
|
||||
}
|
||||
} else {
|
||||
for col in 0..<info.column_count {
|
||||
if col > 0 { io.write_string(fi.writer, "; ") }
|
||||
for row in 0..<info.row_count {
|
||||
if row > 0 { io.write_string(fi.writer, ", ") }
|
||||
|
||||
offset := row*info.elem_size + col*info.stride
|
||||
|
||||
data := uintptr(v.data) + uintptr(offset)
|
||||
fmt_arg(fi, any{rawptr(data), info.elem.id}, verb)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1400,8 +1400,9 @@ bool check_unary_op(CheckerContext *c, Operand *o, Token op) {
|
||||
}
|
||||
|
||||
bool check_binary_op(CheckerContext *c, Operand *o, Token op) {
|
||||
Type *main_type = o->type;
|
||||
// TODO(bill): Handle errors correctly
|
||||
Type *type = base_type(core_array_type(o->type));
|
||||
Type *type = base_type(core_array_type(main_type));
|
||||
Type *ct = core_type(type);
|
||||
switch (op.kind) {
|
||||
case Token_Sub:
|
||||
@@ -1414,10 +1415,15 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) {
|
||||
}
|
||||
break;
|
||||
|
||||
case Token_Mul:
|
||||
case Token_Quo:
|
||||
case Token_MulEq:
|
||||
case Token_QuoEq:
|
||||
if (is_type_matrix(main_type)) {
|
||||
error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string));
|
||||
return false;
|
||||
}
|
||||
/*fallthrough*/
|
||||
case Token_Mul:
|
||||
case Token_MulEq:
|
||||
case Token_AddEq:
|
||||
if (is_type_bit_set(type)) {
|
||||
return true;
|
||||
@@ -1458,6 +1464,10 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) {
|
||||
case Token_ModMod:
|
||||
case Token_ModEq:
|
||||
case Token_ModModEq:
|
||||
if (is_type_matrix(main_type)) {
|
||||
error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string));
|
||||
return false;
|
||||
}
|
||||
if (!is_type_integer(type)) {
|
||||
error(op, "Operator '%.*s' is only allowed with integers", LIT(op.string));
|
||||
return false;
|
||||
@@ -2671,6 +2681,114 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type
|
||||
}
|
||||
|
||||
|
||||
void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) {
|
||||
if (!check_binary_op(c, x, op)) {
|
||||
x->mode = Addressing_Invalid;
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_type_matrix(x->type)) {
|
||||
Type *xt = base_type(x->type);
|
||||
Type *yt = base_type(y->type);
|
||||
GB_ASSERT(xt->kind == Type_Matrix);
|
||||
if (op.kind == Token_Mul) {
|
||||
if (yt->kind == Type_Matrix) {
|
||||
if (!are_types_identical(xt->Matrix.elem, yt->Matrix.elem)) {
|
||||
goto matrix_error;
|
||||
}
|
||||
|
||||
if (xt->Matrix.column_count != yt->Matrix.row_count) {
|
||||
goto matrix_error;
|
||||
}
|
||||
x->mode = Addressing_Value;
|
||||
x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count);
|
||||
goto matrix_success;
|
||||
} else if (yt->kind == Type_Array) {
|
||||
if (!are_types_identical(xt->Matrix.elem, yt->Array.elem)) {
|
||||
goto matrix_error;
|
||||
}
|
||||
|
||||
if (xt->Matrix.column_count != yt->Array.count) {
|
||||
goto matrix_error;
|
||||
}
|
||||
|
||||
// Treat arrays as column vectors
|
||||
x->mode = Addressing_Value;
|
||||
x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1);
|
||||
goto matrix_success;
|
||||
}
|
||||
}
|
||||
if (!are_types_identical(xt, yt)) {
|
||||
goto matrix_error;
|
||||
}
|
||||
x->mode = Addressing_Value;
|
||||
x->type = xt;
|
||||
goto matrix_success;
|
||||
} else {
|
||||
Type *xt = base_type(x->type);
|
||||
Type *yt = base_type(y->type);
|
||||
GB_ASSERT(is_type_matrix(yt));
|
||||
GB_ASSERT(!is_type_matrix(xt));
|
||||
|
||||
if (op.kind == Token_Mul) {
|
||||
// NOTE(bill): no need to handle the matrix case here since it should be handled above
|
||||
if (xt->kind == Type_Array) {
|
||||
if (!are_types_identical(yt->Matrix.elem, xt->Array.elem)) {
|
||||
goto matrix_error;
|
||||
}
|
||||
|
||||
if (xt->Array.count != yt->Matrix.row_count) {
|
||||
goto matrix_error;
|
||||
}
|
||||
|
||||
// Treat arrays as row vectors
|
||||
x->mode = Addressing_Value;
|
||||
x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count);
|
||||
goto matrix_success;
|
||||
}
|
||||
}
|
||||
if (!are_types_identical(xt, yt)) {
|
||||
goto matrix_error;
|
||||
}
|
||||
x->mode = Addressing_Value;
|
||||
x->type = xt;
|
||||
goto matrix_success;
|
||||
}
|
||||
|
||||
matrix_success:
|
||||
if (type_hint != nullptr) {
|
||||
Type *th = base_type(type_hint);
|
||||
if (are_types_identical(th, x->type)) {
|
||||
x->type = type_hint;
|
||||
} else if (x->type->kind == Type_Matrix && th->kind == Type_Array) {
|
||||
Type *xt = x->type;
|
||||
if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) {
|
||||
// ignore
|
||||
} else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) {
|
||||
x->type = type_hint;
|
||||
} else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) {
|
||||
x->type = type_hint;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
||||
|
||||
matrix_error:
|
||||
gbString xt = type_to_string(x->type);
|
||||
gbString yt = type_to_string(y->type);
|
||||
gbString expr_str = expr_to_string(x->expr);
|
||||
error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt);
|
||||
gb_string_free(expr_str);
|
||||
gb_string_free(yt);
|
||||
gb_string_free(xt);
|
||||
x->type = t_invalid;
|
||||
x->mode = Addressing_Invalid;
|
||||
return;
|
||||
|
||||
}
|
||||
|
||||
|
||||
void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) {
|
||||
GB_ASSERT(node->kind == Ast_BinaryExpr);
|
||||
Operand y_ = {}, *y = &y_;
|
||||
@@ -2874,6 +2992,12 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
|
||||
x->type = y->type;
|
||||
return;
|
||||
}
|
||||
if (is_type_matrix(x->type) || is_type_matrix(y->type)) {
|
||||
check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
if (!are_types_identical(x->type, y->type)) {
|
||||
if (x->type != t_invalid &&
|
||||
y->type != t_invalid) {
|
||||
@@ -3258,6 +3382,29 @@ void convert_to_typed(CheckerContext *c, Operand *operand, Type *target_type) {
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case Type_Matrix: {
|
||||
Type *elem = base_array_type(t);
|
||||
if (check_is_assignable_to(c, operand, elem)) {
|
||||
if (t->Matrix.row_count != t->Matrix.column_count) {
|
||||
operand->mode = Addressing_Invalid;
|
||||
begin_error_block();
|
||||
defer (end_error_block());
|
||||
|
||||
convert_untyped_error(c, operand, target_type);
|
||||
error_line("\tNote: Only a square matrix types can be initialized with a scalar value\n");
|
||||
return;
|
||||
} else {
|
||||
operand->mode = Addressing_Value;
|
||||
}
|
||||
} else {
|
||||
operand->mode = Addressing_Invalid;
|
||||
convert_untyped_error(c, operand, target_type);
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case Type_Union:
|
||||
if (!is_operand_nil(*operand) && !is_operand_undef(*operand)) {
|
||||
|
||||
@@ -1659,6 +1659,10 @@ void add_type_info_type_internal(CheckerContext *c, Type *t) {
|
||||
add_type_info_type_internal(c, bt->RelativeSlice.slice_type);
|
||||
add_type_info_type_internal(c, bt->RelativeSlice.base_integer);
|
||||
break;
|
||||
|
||||
case Type_Matrix:
|
||||
add_type_info_type_internal(c, bt->Matrix.elem);
|
||||
break;
|
||||
|
||||
default:
|
||||
GB_PANIC("Unhandled type: %*.s %d", LIT(type_strings[bt->kind]), bt->kind);
|
||||
@@ -1870,6 +1874,10 @@ void add_min_dep_type_info(Checker *c, Type *t) {
|
||||
add_min_dep_type_info(c, bt->RelativeSlice.slice_type);
|
||||
add_min_dep_type_info(c, bt->RelativeSlice.base_integer);
|
||||
break;
|
||||
|
||||
case Type_Matrix:
|
||||
add_min_dep_type_info(c, bt->Matrix.elem);
|
||||
break;
|
||||
|
||||
default:
|
||||
GB_PANIC("Unhandled type: %*.s", LIT(type_strings[bt->kind]));
|
||||
|
||||
@@ -333,6 +333,10 @@ lbValue lb_emit_array_ep(lbProcedure *p, lbValue s, lbValue index);
|
||||
lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel);
|
||||
lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel);
|
||||
|
||||
lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column);
|
||||
lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column);
|
||||
|
||||
|
||||
lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type);
|
||||
lbValue lb_emit_byte_swap(lbProcedure *p, lbValue value, Type *end_type);
|
||||
void lb_emit_defer_stmts(lbProcedure *p, lbDeferExitKind kind, lbBlock *block);
|
||||
|
||||
@@ -512,6 +512,34 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc
|
||||
|
||||
res.value = llvm_const_array(lb_type(m, elem), elems, cast(unsigned)count);
|
||||
return res;
|
||||
} else if (is_type_matrix(type) &&
|
||||
value.kind != ExactValue_Invalid &&
|
||||
value.kind != ExactValue_Compound) {
|
||||
i64 row = type->Matrix.row_count;
|
||||
i64 column = type->Matrix.column_count;
|
||||
GB_ASSERT(row == column);
|
||||
|
||||
Type *elem = type->Matrix.elem;
|
||||
|
||||
lbValue single_elem = lb_const_value(m, elem, value, allow_local);
|
||||
single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem));
|
||||
|
||||
i64 stride_bytes = matrix_type_stride(type);
|
||||
i64 stride_elems = stride_bytes/type_size_of(elem);
|
||||
|
||||
i64 total_elem_count = matrix_type_total_elems(type);
|
||||
LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count);
|
||||
for (i64 i = 0; i < row; i++) {
|
||||
elems[i*stride_elems + i] = single_elem.value;
|
||||
}
|
||||
for (i64 i = 0; i < total_elem_count; i++) {
|
||||
if (elems[i] == nullptr) {
|
||||
elems[i] = LLVMConstNull(lb_type(m, elem));
|
||||
}
|
||||
}
|
||||
|
||||
res.value = LLVMConstArray(lb_type(m, elem), elems, cast(unsigned)total_elem_count);
|
||||
return res;
|
||||
}
|
||||
|
||||
switch (value.kind) {
|
||||
|
||||
@@ -477,10 +477,72 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
|
||||
}
|
||||
|
||||
|
||||
lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
|
||||
GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
|
||||
|
||||
Type *xt = base_type(lhs.type);
|
||||
Type *yt = base_type(rhs.type);
|
||||
|
||||
if (op == Token_Mul) {
|
||||
if (xt->kind == Type_Matrix) {
|
||||
if (yt->kind == Type_Matrix) {
|
||||
GB_ASSERT(is_type_matrix(type));
|
||||
GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
|
||||
GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
|
||||
|
||||
Type *elem = xt->Matrix.elem;
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
for (i64 i = 0; i < xt->Matrix.row_count; i++) {
|
||||
for (i64 j = 0; j < yt->Matrix.column_count; j++) {
|
||||
for (i64 k = 0; k < xt->Matrix.column_count; k++) {
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
||||
|
||||
lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
|
||||
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
||||
lbValue c = lb_emit_arith(p, op, a, b, elem);
|
||||
lbValue d = lb_emit_load(p, dst);
|
||||
lbValue e = lb_emit_arith(p, Token_Add, d, c, elem);
|
||||
lb_emit_store(p, dst, e);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
GB_ASSERT(are_types_identical(xt, yt));
|
||||
GB_ASSERT(xt->kind == Type_Matrix);
|
||||
// element-wise arithmetic
|
||||
// pretend it is an array
|
||||
lbValue array_lhs = lhs;
|
||||
lbValue array_rhs = rhs;
|
||||
Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt));
|
||||
GB_ASSERT(type_size_of(array_type) == type_size_of(type));
|
||||
|
||||
array_lhs.type = array_type;
|
||||
array_rhs.type = array_type;
|
||||
|
||||
lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type);
|
||||
array.type = type;
|
||||
return array;
|
||||
}
|
||||
|
||||
GB_PANIC("TODO: lb_emit_arith_matrix");
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
|
||||
lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
|
||||
if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) {
|
||||
return lb_emit_arith_array(p, op, lhs, rhs, type);
|
||||
} else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) {
|
||||
return lb_emit_arith_matrix(p, op, lhs, rhs, type);
|
||||
} else if (is_type_complex(type)) {
|
||||
lhs = lb_emit_conv(p, lhs, type);
|
||||
rhs = lb_emit_conv(p, rhs, type);
|
||||
@@ -1417,6 +1479,22 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
|
||||
}
|
||||
return lb_addr_load(p, v);
|
||||
}
|
||||
|
||||
if (is_type_matrix(dst) && !is_type_matrix(src)) {
|
||||
GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count);
|
||||
|
||||
Type *elem = base_array_type(dst);
|
||||
lbValue e = lb_emit_conv(p, value, elem);
|
||||
lbAddr v = lb_add_local_generated(p, t, false);
|
||||
for (i64 i = 0; i < dst->Matrix.row_count; i++) {
|
||||
isize j = cast(isize)i;
|
||||
lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j);
|
||||
lb_emit_store(p, ptr, e);
|
||||
}
|
||||
|
||||
|
||||
return lb_addr_load(p, v);
|
||||
}
|
||||
|
||||
if (is_type_any(dst)) {
|
||||
if (is_type_untyped_nil(src)) {
|
||||
|
||||
@@ -1221,6 +1221,41 @@ lbValue lb_emit_ptr_offset(lbProcedure *p, lbValue ptr, lbValue index) {
|
||||
return res;
|
||||
}
|
||||
|
||||
lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) {
|
||||
Type *t = s.type;
|
||||
GB_ASSERT(is_type_pointer(t));
|
||||
Type *st = base_type(type_deref(t));
|
||||
GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
|
||||
|
||||
Type *ptr = base_array_type(st);
|
||||
|
||||
isize index = row*column;
|
||||
GB_ASSERT(0 <= index);
|
||||
|
||||
LLVMValueRef indices[2] = {
|
||||
LLVMConstInt(lb_type(p->module, t_int), 0, false),
|
||||
LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)index, false),
|
||||
};
|
||||
|
||||
lbValue res = {};
|
||||
if (lb_is_const(s)) {
|
||||
res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
|
||||
} else {
|
||||
res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
|
||||
}
|
||||
res.type = alloc_type_pointer(ptr);
|
||||
return res;
|
||||
}
|
||||
|
||||
lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
|
||||
Type *st = base_type(s.type);
|
||||
GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
|
||||
|
||||
lbValue value = lb_address_from_load_or_generate_local(p, s);
|
||||
lbValue ptr = lb_emit_matrix_epi(p, value, row, column);
|
||||
return lb_emit_load(p, ptr);
|
||||
}
|
||||
|
||||
|
||||
void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) {
|
||||
Type *t = lb_addr_type(slice);
|
||||
|
||||
@@ -1257,6 +1257,22 @@ i64 matrix_type_stride(Type *t) {
|
||||
return stride;
|
||||
}
|
||||
|
||||
i64 matrix_type_stride_in_elems(Type *t) {
|
||||
t = base_type(t);
|
||||
GB_ASSERT(t->kind == Type_Matrix);
|
||||
i64 stride = matrix_type_stride(t);
|
||||
return stride/gb_max(1, type_size_of(t->Matrix.elem));
|
||||
}
|
||||
|
||||
|
||||
i64 matrix_type_total_elems(Type *t) {
|
||||
t = base_type(t);
|
||||
GB_ASSERT(t->kind == Type_Matrix);
|
||||
i64 size = type_size_of(t);
|
||||
i64 elem_size = type_size_of(t->Matrix.elem);
|
||||
return size/gb_max(elem_size, 1);
|
||||
}
|
||||
|
||||
bool is_type_dynamic_array(Type *t) {
|
||||
t = base_type(t);
|
||||
return t->kind == Type_DynamicArray;
|
||||
@@ -3174,17 +3190,17 @@ i64 type_align_of_internal(Type *t, TypePath *path) {
|
||||
|
||||
case Type_Matrix: {
|
||||
Type *elem = t->Matrix.elem;
|
||||
i64 row_count = t->Matrix.row_count;
|
||||
// i64 column_count = t->Matrix.column_count;
|
||||
i64 row_count = gb_max(t->Matrix.row_count, 1);
|
||||
|
||||
bool pop = type_path_push(path, elem);
|
||||
if (path->failure) {
|
||||
return FAILURE_ALIGNMENT;
|
||||
}
|
||||
// elem align is used here rather than size as it make a little more sense
|
||||
i64 elem_align = type_align_of_internal(elem, path);
|
||||
if (pop) type_path_pop(path);
|
||||
|
||||
i64 align = gb_clamp(elem_align * row_count, elem_align, build_context.max_align);
|
||||
|
||||
i64 align = gb_min(next_pow2(elem_align * row_count), build_context.max_align);
|
||||
return align;
|
||||
}
|
||||
|
||||
@@ -3935,6 +3951,13 @@ gbString write_type_to_string(gbString str, Type *type) {
|
||||
str = gb_string_append_fmt(str, ") ");
|
||||
str = write_type_to_string(str, type->RelativeSlice.slice_type);
|
||||
break;
|
||||
|
||||
case Type_Matrix:
|
||||
str = gb_string_appendc(str, gb_bprintf("[%d", cast(int)type->Matrix.row_count));
|
||||
str = gb_string_appendc(str, "; ");
|
||||
str = gb_string_appendc(str, gb_bprintf("%d]", cast(int)type->Matrix.column_count));
|
||||
str = write_type_to_string(str, type->Matrix.elem);
|
||||
break;
|
||||
}
|
||||
|
||||
return str;
|
||||
|
||||
Reference in New Issue
Block a user