Improve matrix * vector code gen

This commit is contained in:
gingerBill
2026-03-15 21:03:31 +00:00
parent 12b06887a3
commit 6b2853d9f1
2 changed files with 32 additions and 11 deletions

View File

@@ -967,7 +967,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
}
for (unsigned row_index = 0; row_index < column_count; row_index++) {
LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
LLVMValueRef value = LLVMBuildExtractValue(p->builder, rhs.value, row_index, "");
LLVMValueRef row = llvm_vector_broadcast(p, value, row_count);
v_rows[row_index] = row;
}
@@ -988,13 +988,19 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
lbAddr res = lb_add_local_generated(p, type, true);
Type *vector_elem_type = base_array_type(rhs.type);
for (i64 i = 0; i < mt->Matrix.row_count; i++) {
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
lbValue d0 = lb_emit_load(p, dst);
lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
LLVMValueRef b_value = LLVMBuildExtractValue(p->builder, rhs.value, cast(unsigned)j, "");
lbValue b = {b_value, vector_elem_type};
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
lb_emit_store(p, dst, c);
}
@@ -1043,7 +1049,7 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
}
for (unsigned column_index = 0; column_index < row_count; column_index++) {
LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value;
LLVMValueRef value = LLVMBuildExtractValue(p->builder, lhs.value, column_index, "");
LLVMValueRef row = llvm_vector_broadcast(p, value, column_count);
v_rows[column_index] = row;
}
@@ -1072,12 +1078,16 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
lbAddr res = lb_add_local_generated(p, type, true);
Type *vector_elem_type = base_array_type(rhs.type);
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
for (i64 k = 0; k < mt->Matrix.row_count; k++) {
lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
lbValue d0 = lb_emit_load(p, dst);
lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
LLVMValueRef a_value = LLVMBuildExtractValue(p->builder, lhs.value, cast(unsigned)k, "");
lbValue a = {a_value, vector_elem_type};
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
lb_emit_store(p, dst, c);

View File

@@ -1742,14 +1742,25 @@ gb_internal lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lb
return res;
}
gb_internal lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
Type *st = base_type(s.type);
GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
lbValue value = lb_address_from_load_or_generate_local(p, s);
lbValue ptr = lb_emit_matrix_epi(p, value, row, column);
return lb_emit_load(p, ptr);
Type *t = s.type;
Type *mt = base_type(t);
GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
isize stride_elems = matrix_type_stride_in_elems(mt);
isize index = -1;
if (mt->Matrix.is_row_major) {
index = column + (row * stride_elems);
} else {
index = row + (column * stride_elems);
}
lbValue res = {};
res.value = LLVMBuildExtractValue(p->builder, s.value, cast(unsigned)index, "");
res.type = base_array_type(mt);
return res;
}
gb_internal void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) {