diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 968d27a4e..53e0d32de 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -866,8 +866,10 @@ gb_internal lbValue lb_emit_matrix_transpose(lbProcedure *p, lbValue m, Type *ty return lb_addr_load(p, res); } -gb_internal lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) { - lbAddr res = lb_add_local_generated(p, type, true); +gb_internal lbAddr llvm_add_local_generated_from_vector(lbProcedure *p, Type *type, LLVMValueRef vector) { + GB_ASSERT(LLVMGetTypeKind(LLVMTypeOf(vector)) == LLVMVectorTypeKind); + + lbAddr res = lb_add_local_generated(p, type, false); LLVMValueRef res_ptr = res.addr.value; unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); LLVMSetAlignment(res_ptr, alignment); @@ -875,9 +877,16 @@ gb_internal lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef v res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); LLVMBuildStore(p->builder, vector, res_ptr); + return res; +} + +gb_internal lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) { + lbAddr res = llvm_add_local_generated_from_vector(p, type, vector); return lb_addr_load(p, res); } + + gb_internal lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) { if (is_type_array(m.type)) { // no-op @@ -895,31 +904,6 @@ gb_internal lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type lbValue n = lb_const_int(p->module, t_int, type_size_of(type)); lb_mem_copy_non_overlapping(p, res.addr, m_ptr, n); - // i64 row_count = mt->Matrix.row_count; - // i64 column_count = mt->Matrix.column_count; - // TEMPORARY_ALLOCATOR_GUARD(); - - // auto srcs = array_make(temporary_allocator(), 0, row_count*column_count); - // auto dsts = array_make(temporary_allocator(), 0, row_count*column_count); - - // for (i64 j = 0; j < column_count; j++) { - // for (i64 i = 0; i < row_count; i++) { - // lbValue src = lb_emit_matrix_ev(p, m, i, j); - // array_add(&srcs, src); - // } - // } - - // for (i64 j = 0; j < column_count; j++) { - // for (i64 i = 0; i < row_count; i++) { - // lbValue dst = lb_emit_array_epi(p, res.addr, i + j*row_count); - // array_add(&dsts, dst); - // } - // } - - // GB_ASSERT(srcs.count == dsts.count); - // for_array(i, srcs) { - // lb_emit_store(p, dsts[i], srcs[i]); - // } return lb_addr_load(p, res); } @@ -1328,27 +1312,13 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal GB_ASSERT(row_count > 0); - LLVMValueRef vector = nullptr; - for (i64 i = 0; i < row_count; i++) { - if (i == 0) { - vector = llvm_vector_mul(p, v_rows[i], m_columns[i]); - } else { - vector = llvm_vector_mul_add(p, v_rows[i], m_columns[i], vector); - } - } - - lbAddr res = lb_add_local_generated(p, type, true); - LLVMValueRef res_ptr = res.addr.value; - unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); - LLVMSetAlignment(res_ptr, alignment); - - res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); - LLVMBuildStore(p->builder, vector, res_ptr); + LLVMValueRef vector = llvm_vector_mul_pairwise_reduce_add(p, v_rows, m_columns); + lbAddr res = llvm_add_local_generated_from_vector(p, type, vector); return lb_addr_load(p, res); } - lbAddr res = lb_add_local_generated(p, type, true); + lbAddr res = lb_add_local_generated(p, type, false); Type *vector_elem_type = base_array_type(rhs.type); diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index 25481b2ed..bb3b4dadb 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -2293,6 +2293,7 @@ gb_internal LLVMValueRef llvm_vector_mul_add(lbProcedure *p, LLVMValueRef a, LLV } } + gb_internal LLVMValueRef llvm_get_inline_asm(LLVMTypeRef func_type, String const &str, String const &clobbers, bool has_side_effects=true, bool is_align_stack=false, LLVMInlineAsmDialect dialect=LLVMInlineAsmDialectATT) { return LLVMGetInlineAsm(func_type, cast(char *)str.text, cast(size_t)str.len,