mirror of
https://github.com/odin-lang/Odin.git
synced 2026-06-14 14:23:43 +00:00
Improve matrix * vector code gen
This commit is contained in:
@@ -967,7 +967,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
|
||||
}
|
||||
|
||||
for (unsigned row_index = 0; row_index < column_count; row_index++) {
|
||||
LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
|
||||
LLVMValueRef value = LLVMBuildExtractValue(p->builder, rhs.value, row_index, "");
|
||||
LLVMValueRef row = llvm_vector_broadcast(p, value, row_count);
|
||||
v_rows[row_index] = row;
|
||||
}
|
||||
@@ -988,13 +988,19 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
|
||||
Type *vector_elem_type = base_array_type(rhs.type);
|
||||
|
||||
for (i64 i = 0; i < mt->Matrix.row_count; i++) {
|
||||
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
|
||||
lbValue d0 = lb_emit_load(p, dst);
|
||||
|
||||
lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
|
||||
lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
|
||||
|
||||
|
||||
LLVMValueRef b_value = LLVMBuildExtractValue(p->builder, rhs.value, cast(unsigned)j, "");
|
||||
lbValue b = {b_value, vector_elem_type};
|
||||
|
||||
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
|
||||
lb_emit_store(p, dst, c);
|
||||
}
|
||||
@@ -1043,7 +1049,7 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
|
||||
}
|
||||
|
||||
for (unsigned column_index = 0; column_index < row_count; column_index++) {
|
||||
LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value;
|
||||
LLVMValueRef value = LLVMBuildExtractValue(p->builder, lhs.value, column_index, "");
|
||||
LLVMValueRef row = llvm_vector_broadcast(p, value, column_count);
|
||||
v_rows[column_index] = row;
|
||||
}
|
||||
@@ -1072,12 +1078,16 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, true);
|
||||
|
||||
Type *vector_elem_type = base_array_type(rhs.type);
|
||||
|
||||
for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
||||
for (i64 k = 0; k < mt->Matrix.row_count; k++) {
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
|
||||
lbValue d0 = lb_emit_load(p, dst);
|
||||
|
||||
lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
|
||||
LLVMValueRef a_value = LLVMBuildExtractValue(p->builder, lhs.value, cast(unsigned)k, "");
|
||||
lbValue a = {a_value, vector_elem_type};
|
||||
|
||||
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
||||
lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
|
||||
lb_emit_store(p, dst, c);
|
||||
|
||||
@@ -1742,14 +1742,25 @@ gb_internal lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lb
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// Reads the element at (row, column) out of the matrix value `s`.
// `s.type` must resolve (via base_type) to a matrix type; this is asserted.
//
// NOTE(review): this span is taken from a commit-diff view that carries no +/-
// markers, so two bodies appear back to back. The statements up to and including
// `return lb_emit_load(p, ptr);` look like the removed implementation, and the
// statements after it look like the replacement -- confirm against the repository.
gb_internal lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
|
||||
Type *st = base_type(s.type);
|
||||
GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
|
||||
|
||||
// Address-based path: get an address for `s`, form the element pointer for
// (row, column), then load through that pointer.
// (presumably the first helper spills `s` to a local when it has no address -- verify)
lbValue value = lb_address_from_load_or_generate_local(p, s);
|
||||
lbValue ptr = lb_emit_matrix_epi(p, value, row, column);
|
||||
return lb_emit_load(p, ptr);
|
||||
// Value-based path: compute a flat element index and extract directly from `s.value`.
Type *t = s.type;
|
||||
Type *mt = base_type(t);
|
||||
GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
|
||||
|
||||
// Stride in elements: multiplied by `row` for row-major layout and by `column` otherwise.
isize stride_elems = matrix_type_stride_in_elems(mt);
|
||||
|
||||
// Placeholder value; both branches below overwrite it.
isize index = -1;
|
||||
|
||||
if (mt->Matrix.is_row_major) {
|
||||
// Row-major: elements of one row are adjacent, rows are `stride_elems` apart.
index = column + (row * stride_elems);
|
||||
} else {
|
||||
// Otherwise: elements of one column are adjacent, columns are `stride_elems` apart.
index = row + (column * stride_elems);
|
||||
}
|
||||
|
||||
lbValue res = {};
|
||||
// Extract the element from the aggregate value itself at the flat index;
// this path issues no address computation or load.
res.value = LLVMBuildExtractValue(p->builder, s.value, cast(unsigned)index, "");
|
||||
// The result is typed as the matrix's element type.
res.type = base_array_type(mt);
|
||||
return res;
|
||||
}
|
||||
|
||||
gb_internal void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) {
|
||||
|
||||
Reference in New Issue
Block a user