Improve matrix * vector code gen

This commit is contained in:
gingerBill
2026-03-15 21:03:31 +00:00
parent 12b06887a3
commit 6b2853d9f1
2 changed files with 32 additions and 11 deletions

View File

@@ -967,7 +967,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
}
for (unsigned row_index = 0; row_index < column_count; row_index++) {
LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
LLVMValueRef value = LLVMBuildExtractValue(p->builder, rhs.value, row_index, "");
LLVMValueRef row = llvm_vector_broadcast(p, value, row_count);
v_rows[row_index] = row;
}
@@ -988,13 +988,19 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
lbAddr res = lb_add_local_generated(p, type, true);
Type *vector_elem_type = base_array_type(rhs.type);
for (i64 i = 0; i < mt->Matrix.row_count; i++) {
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
lbValue d0 = lb_emit_load(p, dst);
lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
LLVMValueRef b_value = LLVMBuildExtractValue(p->builder, rhs.value, cast(unsigned)j, "");
lbValue b = {b_value, vector_elem_type};
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
lb_emit_store(p, dst, c);
}
@@ -1043,7 +1049,7 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
}
for (unsigned column_index = 0; column_index < row_count; column_index++) {
LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value;
LLVMValueRef value = LLVMBuildExtractValue(p->builder, lhs.value, column_index, "");
LLVMValueRef row = llvm_vector_broadcast(p, value, column_count);
v_rows[column_index] = row;
}
@@ -1072,12 +1078,16 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
lbAddr res = lb_add_local_generated(p, type, true);
Type *vector_elem_type = base_array_type(rhs.type);
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
for (i64 k = 0; k < mt->Matrix.row_count; k++) {
lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
lbValue d0 = lb_emit_load(p, dst);
lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
LLVMValueRef a_value = LLVMBuildExtractValue(p->builder, lhs.value, cast(unsigned)k, "");
lbValue a = {a_value, vector_elem_type};
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
lb_emit_store(p, dst, c);