diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index c0190a0b9..dc5b529ea 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -1967,7 +1967,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { for col in 0..<info.column_count { if col > 0 { io.write_string(fi.writer, ", ") } - offset := row*info.elem_size + col*info.stride + offset := (row + col*info.elem_stride)*info.elem_size data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) @@ -1980,7 +1980,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { for col in 0..<info.column_count { if col > 0 { io.write_string(fi.writer, "; ") } - offset := row*info.elem_size + col*info.stride + offset := (row + col*info.elem_stride)*info.elem_size data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) diff --git a/core/runtime/core.odin b/core/runtime/core.odin index 611b4002c..ba1e81da6 100644 --- a/core/runtime/core.odin +++ b/core/runtime/core.odin @@ -165,7 +165,7 @@ Type_Info_Relative_Slice :: struct { Type_Info_Matrix :: struct { elem: ^Type_Info, elem_size: int, - stride: int, // bytes + elem_stride: int, row_count: int, column_count: int, } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 73e1a7e51..eb6040320 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -7369,6 +7369,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case Type_Array: case Type_DynamicArray: case Type_SimdVector: + case Type_Matrix: { Type *elem_type = nullptr; String context_name = {}; @@ -7395,6 +7396,10 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type elem_type = t->SimdVector.elem; context_name = str_lit("simd vector literal"); max_type_count = t->SimdVector.count; + } else if (t->kind == Type_Matrix) { + elem_type = t->Matrix.elem; + context_name = str_lit("matrix literal"); + max_type_count = t->Matrix.row_count*t->Matrix.column_count; } else { GB_PANIC("unreachable"); } diff --git a/src/llvm_backend.hpp 
b/src/llvm_backend.hpp index 9041e7621..d2abed354 100644 --- a/src/llvm_backend.hpp +++ b/src/llvm_backend.hpp @@ -393,6 +393,8 @@ lbValue lb_soa_struct_len(lbProcedure *p, lbValue value); void lb_emit_increment(lbProcedure *p, lbValue addr); lbValue lb_emit_select(lbProcedure *p, lbValue cond, lbValue x, lbValue y); +lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t); + void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len); lbValue lb_type_info(lbModule *m, Type *type); diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 4cfcecdc3..413fb365b 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -523,14 +523,11 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc lbValue single_elem = lb_const_value(m, elem, value, allow_local); single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); - - i64 stride_bytes = matrix_type_stride(type); - i64 stride_elems = stride_bytes/type_size_of(elem); - + i64 total_elem_count = matrix_type_total_elems(type); LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); for (i64 i = 0; i < row; i++) { - elems[i*stride_elems + i] = single_elem.value; + elems[matrix_indices_to_offset(type, i, i)] = single_elem.value; /* diagonal element (row i, column i); a linear index i would address (0, i) instead */ } for (i64 i = 0; i < total_elem_count; i++) { if (elems[i] == nullptr) { @@ -984,6 +981,82 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc res.value = LLVMConstInt(lb_type(m, original_type), bits, false); return res; + } else if (is_type_matrix(type)) { + ast_node(cl, CompoundLit, value.value_compound); + Type *elem_type = type->Matrix.elem; + isize elem_count = cl->elems.count; + if (elem_count == 0 || !elem_type_can_be_constant(elem_type)) { + return lb_const_nil(m, original_type); + } + + i64 max_count = type->Matrix.row_count*type->Matrix.column_count; + i64 total_count = 
matrix_type_total_elems(type); + + LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); + if (cl->elems[0]->kind == Ast_FieldValue) { + for_array(j, cl->elems) { + Ast *elem = cl->elems[j]; + ast_node(fv, FieldValue, elem); + if (is_ast_range(fv->field)) { + ast_node(ie, BinaryExpr, fv->field); + TypeAndValue lo_tav = ie->left->tav; + TypeAndValue hi_tav = ie->right->tav; + GB_ASSERT(lo_tav.mode == Addressing_Constant); + GB_ASSERT(hi_tav.mode == Addressing_Constant); + + TokenKind op = ie->op.kind; + i64 lo = exact_value_to_i64(lo_tav.value); + i64 hi = exact_value_to_i64(hi_tav.value); + if (op != Token_RangeHalf) { + hi += 1; + } + TypeAndValue tav = fv->value->tav; + LLVMValueRef val = lb_const_value(m, elem_type, tav.value, allow_local).value; + for (i64 k = lo; k < hi; k++) { + i64 offset = matrix_index_to_offset(type, k); + GB_ASSERT(values[offset] == nullptr); + values[offset] = val; + } + } else { + TypeAndValue index_tav = fv->field->tav; + GB_ASSERT(index_tav.mode == Addressing_Constant); + i64 index = exact_value_to_i64(index_tav.value); + TypeAndValue tav = fv->value->tav; + LLVMValueRef val = lb_const_value(m, elem_type, tav.value, allow_local).value; + i64 offset = matrix_index_to_offset(type, index); + GB_ASSERT(values[offset] == nullptr); + values[offset] = val; + } + } + + for (i64 i = 0; i < total_count; i++) { + if (values[i] == nullptr) { + values[i] = LLVMConstNull(lb_type(m, elem_type)); + } + } + + res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)total_count, values, allow_local); + return res; + } else { + GB_ASSERT_MSG(elem_count == max_count, "%td != %td", elem_count, max_count); + + LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); + + for_array(i, cl->elems) { + TypeAndValue tav = cl->elems[i]->tav; + GB_ASSERT(tav.mode != Addressing_Invalid); + i64 offset = matrix_index_to_offset(type, i); + values[offset] = 
lb_const_value(m, elem_type, tav.value, allow_local).value; + } + for (isize i = 0; i < total_count; i++) { + if (values[i] == nullptr) { + values[i] = LLVMConstNull(lb_type(m, elem_type)); + } + } + + res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)total_count, values, allow_local); + return res; + } } else { return lb_const_nil(m, original_type); } diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index bcbb77355..518ce33af 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -648,18 +648,23 @@ slow_form: i64 inner = xt->Matrix.column_count; i64 outer_columns = yt->Matrix.column_count; + auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner); + for (i64 j = 0; j < outer_columns; j++) { for (i64 i = 0; i < outer_rows; i++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); for (i64 k = 0; k < inner; k++) { - lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); - lbValue d0 = lb_emit_load(p, dst); - - lbValue a = lb_emit_matrix_ev(p, lhs, i, k); - lbValue b = lb_emit_matrix_ev(p, rhs, k, j); - lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem); - lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem); - lb_emit_store(p, dst, d); + inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k); + inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j); } + + lbValue sum = lb_emit_load(p, dst); + for (i64 k = 0; k < inner; k++) { + lbValue a = inners[k][0]; + lbValue b = inners[k][1]; + sum = lb_emit_mul_add(p, a, b, sum, elem); + } + lb_emit_store(p, dst, sum); } } @@ -3626,6 +3631,7 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { case Type_Slice: et = bt->Slice.elem; break; case Type_BitSet: et = bt->BitSet.elem; break; case Type_SimdVector: et = bt->SimdVector.elem; break; + case Type_Matrix: et = bt->Matrix.elem; break; } String proc_name = {}; @@ -4157,7 +4163,104 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { } break; } + + case Type_Matrix: { + if (cl->elems.count > 0) { + lb_addr_store(p, v, 
lb_const_value(p->module, type, exact_value_compound(expr))); + auto temp_data = array_make<lbCompoundLitElemTempData>(temporary_allocator(), 0, cl->elems.count); + + // NOTE(bill): Separate value, gep, store into their own chunks + for_array(i, cl->elems) { + Ast *elem = cl->elems[i]; + + if (elem->kind == Ast_FieldValue) { + ast_node(fv, FieldValue, elem); + if (lb_is_elem_const(fv->value, et)) { + continue; + } + if (is_ast_range(fv->field)) { + ast_node(ie, BinaryExpr, fv->field); + TypeAndValue lo_tav = ie->left->tav; + TypeAndValue hi_tav = ie->right->tav; + GB_ASSERT(lo_tav.mode == Addressing_Constant); + GB_ASSERT(hi_tav.mode == Addressing_Constant); + + TokenKind op = ie->op.kind; + i64 lo = exact_value_to_i64(lo_tav.value); + i64 hi = exact_value_to_i64(hi_tav.value); + if (op != Token_RangeHalf) { + hi += 1; + } + + lbValue value = lb_build_expr(p, fv->value); + + for (i64 k = lo; k < hi; k++) { + lbCompoundLitElemTempData data = {}; + data.value = value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, k); + array_add(&temp_data, data); + } + + } else { + auto tav = fv->field->tav; + GB_ASSERT(tav.mode == Addressing_Constant); + i64 index = exact_value_to_i64(tav.value); + + lbValue value = lb_build_expr(p, fv->value); + lbCompoundLitElemTempData data = {}; + data.value = lb_emit_conv(p, value, et); + data.expr = fv->value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, index); + array_add(&temp_data, data); + } + + } else { + if (lb_is_elem_const(elem, et)) { + continue; + } + lbCompoundLitElemTempData data = {}; + data.expr = elem; + data.elem_index = cast(i32)matrix_index_to_offset(bt, i); + array_add(&temp_data, data); + } + } + + for_array(i, temp_data) { + temp_data[i].gep = lb_emit_array_epi(p, lb_addr_get_ptr(p, v), temp_data[i].elem_index); + } + + for_array(i, temp_data) { + lbValue field_expr = temp_data[i].value; + Ast *expr = temp_data[i].expr; + + auto prev_hint = lb_set_copy_elision_hint(p, lb_addr(temp_data[i].gep), expr); + + if 
(field_expr.value == nullptr) { + field_expr = lb_build_expr(p, expr); + } + Type *t = field_expr.type; + GB_ASSERT(t->kind != Type_Tuple); + lbValue ev = lb_emit_conv(p, field_expr, et); + + if (!p->copy_elision_hint.used) { + temp_data[i].value = ev; + } + + lb_reset_copy_elision_hint(p, prev_hint); + } + + for_array(i, temp_data) { + if (temp_data[i].value.value != nullptr) { + lb_emit_store(p, temp_data[i].gep, temp_data[i].value); + } + } + } + break; + } + } return v; diff --git a/src/llvm_backend_type.cpp b/src/llvm_backend_type.cpp index 82e20bf60..decb57702 100644 --- a/src/llvm_backend_type.cpp +++ b/src/llvm_backend_type.cpp @@ -877,7 +877,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da LLVMValueRef vals[5] = { lb_get_type_info_ptr(m, t->Matrix.elem).value, lb_const_int(m, t_int, ez).value, - lb_const_int(m, t_int, matrix_type_stride(t)).value, + lb_const_int(m, t_int, matrix_type_stride_in_elems(t)).value, lb_const_int(m, t_int, t->Matrix.row_count).value, lb_const_int(m, t_int, t->Matrix.column_count).value, }; diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index c7e9e1742..fb9264661 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1225,18 +1225,53 @@ lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) { Type *t = s.type; GB_ASSERT(is_type_pointer(t)); Type *mt = base_type(type_deref(t)); - GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); - + Type *ptr = base_array_type(mt); - i64 stride_elems = matrix_type_stride_in_elems(mt); + if (column == 0) { + GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt)); + + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)row, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + 
res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + + Type *ptr = base_array_type(mt); + res.type = alloc_type_pointer(ptr); + return res; + } else if (row == 0 && is_type_array_like(mt)) { + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)column, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + + Type *ptr = base_array_type(mt); + res.type = alloc_type_pointer(ptr); + return res; + } - isize index = row + column*stride_elems; - GB_ASSERT(0 <= index); + + GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); + + isize offset = matrix_indices_to_offset(mt, row, column); LLVMValueRef indices[2] = { LLVMConstInt(lb_type(p->module, t_int), 0, false), - LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)index, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)offset, false), }; lbValue res = {}; @@ -1447,3 +1482,34 @@ lbValue lb_soa_struct_cap(lbProcedure *p, lbValue value) { } return lb_emit_struct_ev(p, value, cast(i32)n); } + + + +lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t) { + lbModule *m = p->module; + + a = lb_emit_conv(p, a, t); + b = lb_emit_conv(p, b, t); + c = lb_emit_conv(p, c, t); + + if (!is_type_different_to_arch_endianness(t) && is_type_float(t)) { + char const *name = "llvm.fma"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + + LLVMTypeRef types[1] = {}; + types[0] = lb_type(m, t); + + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + LLVMValueRef values[3] = {}; + values[0] = a.value; + values[1] = b.value; + values[2] = c.value; + LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, 
gb_count_of(values), ""); + return {call, t}; + } else { + lbValue x = lb_emit_arith(p, Token_Mul, a, b, t); + lbValue y = lb_emit_arith(p, Token_Add, x, c, t); + return y; + } +} \ No newline at end of file diff --git a/src/parser.cpp b/src/parser.cpp index c29cf70d9..83da481d5 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -2569,6 +2569,7 @@ bool is_literal_type(Ast *node) { case Ast_DynamicArrayType: case Ast_MapType: case Ast_BitSetType: + case Ast_MatrixType: case Ast_CallExpr: return true; } diff --git a/src/types.cpp b/src/types.cpp index ec094b4ff..bbabdf732 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1276,6 +1276,39 @@ i64 matrix_type_total_elems(Type *t) { return size/gb_max(elem_size, 1); } +void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 row_count = t->Matrix.row_count; + i64 column_count = t->Matrix.column_count; + GB_ASSERT(0 <= index && index < row_count*column_count); + + i64 row_index = index / column_count; + i64 column_index = index % column_count; + + if (row_index_) *row_index_ = row_index; + if (column_index_) *column_index_ = column_index; +} + +i64 matrix_index_to_offset(Type *t, i64 index) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + + i64 row_index, column_index; + matrix_indices_from_index(t, index, &row_index, &column_index); + i64 stride_elems = matrix_type_stride_in_elems(t); + return stride_elems*column_index + row_index; +} + +i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count); + GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count); + i64 stride_elems = matrix_type_stride_in_elems(t); + return stride_elems*column_index + row_index; +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray;