Allow conversions between matrices of the same element count

This commit is contained in:
gingerBill
2021-10-21 01:34:39 +01:00
parent e0b9475378
commit 48d277a3c4
4 changed files with 50 additions and 38 deletions

View File

@@ -1293,7 +1293,7 @@ i64 matrix_type_stride_in_elems(Type *t) {
}
i64 matrix_type_total_elems(Type *t) {
i64 matrix_type_total_internal_elems(Type *t) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
i64 size = type_size_of(t);
@@ -1301,30 +1301,6 @@ i64 matrix_type_total_elems(Type *t) {
return size/gb_max(elem_size, 1);
}
void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
i64 row_count = t->Matrix.row_count;
i64 column_count = t->Matrix.column_count;
GB_ASSERT(0 <= index && index < row_count*column_count);
i64 row_index = index / column_count;
i64 column_index = index % column_count;
if (row_index_) *row_index_ = row_index;
if (column_index_) *column_index_ = column_index;
}
i64 matrix_index_to_offset(Type *t, i64 index) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
i64 row_index, column_index;
matrix_indices_from_index(t, index, &row_index, &column_index);
i64 stride_elems = matrix_type_stride_in_elems(t);
return stride_elems*column_index + row_index;
}
i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
@@ -1333,6 +1309,22 @@ i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
i64 stride_elems = matrix_type_stride_in_elems(t);
return stride_elems*column_index + row_index;
}
i64 matrix_index_to_offset(Type *t, i64 index) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
i64 row_index = index%t->Matrix.row_count;
i64 column_index = index/t->Matrix.row_count;
return matrix_indices_to_offset(t, row_index, column_index);
}
bool is_matrix_square(Type *t) {
t = base_type(t);
GB_ASSERT(t->kind == Type_Matrix);
return t->Matrix.row_count == t->Matrix.column_count;
}
bool is_type_valid_for_matrix_elems(Type *t) {
t = base_type(t);