Fixed matrix multiplication panicking in some cases

This commit is contained in:
bakk 2022-01-18 19:53:32 +01:00 committed by PaddiM8
parent b8b7a0e257
commit 777d54cd40
2 changed files with 22 additions and 12 deletions

View File

@ -747,30 +747,28 @@ impl KalkValue {
KalkValue::Vector(new_values) KalkValue::Vector(new_values)
} }
(KalkValue::Matrix(rows), KalkValue::Matrix(rows_rhs)) => { (KalkValue::Matrix(rows), KalkValue::Matrix(rows_rhs)) => {
if rows.first().unwrap().len() != rows_rhs.len() { let lhs_columns = rows.first().unwrap();
if lhs_columns.len() != rows_rhs.len() {
return KalkValue::nan(); return KalkValue::nan();
} }
let mut result = Vec::new(); let rhs_columns = rows_rhs.first().unwrap();
let mut result = vec![vec![KalkValue::from(0f64); rhs_columns.len()]; rows.len()];
// For every row in lhs // For every row in lhs
for i in 0..rows.len() { for i in 0..rows.len() {
let mut dot_products = Vec::new();
// For every column in rhs // For every column in rhs
for j in 0..rows.len() { for j in 0..rhs_columns.len() {
let mut dot_product = KalkValue::from(0f64); let mut sum = KalkValue::from(0f64);
// For every value in the current lhs row // For every value in the current lhs row
for (k, value) in rows[i].iter().enumerate() { for (k, value) in rows[i].iter().enumerate() {
let value_rhs = &rows_rhs[k][j]; let value_rhs = &rows_rhs[k][j];
dot_product = dot_product sum = sum.add_without_unit(&value.clone().mul_without_unit(value_rhs));
.add_without_unit(&value.clone().mul_without_unit(value_rhs));
} }
dot_products.push(dot_product); result[i][j] = sum;
} }
result.push(dot_products);
} }
KalkValue::Matrix(result) KalkValue::Matrix(result)

View File

@ -6,6 +6,14 @@ m_2 = [2, 3, 4
5, 6, 7 5, 6, 7
8, 9, 10] 8, 9, 10]
m_3 = [1, 2, 1
0, 1, 0
2, 3, 4]
m_4 = [2, 5
6, 7
1, 8]
v = (1, 2, 3) v = (1, 2, 3)
m_1 + m_2 = [3, 5, 7 m_1 + m_2 = [3, 5, 7
@ -35,4 +43,8 @@ m_1 * 2 = [2, 4, 6
14, 16, 18] and 14, 16, 18] and
m_1 / 2 = [1/2, 1, 3/2 m_1 / 2 = [1/2, 1, 3/2
2, 5/2, 3 2, 5/2, 3
7/2, 4, 9/2] 7/2, 4, 9/2] and
m_3 * m_4 = [15, 27
6, 7
26, 63]