diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index fd8d77d85..626c6d1ec 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -9,9 +9,10 @@
 use crate::imp_prelude::*;
 use crate::numeric_util;
 
-use crate::{LinalgScalar, Zip};
+use crate::{ErrorKind, LinalgScalar, Zip};
 
 use std::any::TypeId;
+use std::mem::MaybeUninit;
 
 #[cfg(feature = "blas")]
 use std::cmp;
@@ -823,3 +824,215 @@ mod blas_tests {
         assert!(blas_column_major_2d::<f32>(&m));
     }
 }
+
+/// Computes a Kronecker-style product of `a` and `b`, combining element pairs with `f`.
+///
+/// Both shapes are padded with trailing axes of length 1 until they have
+/// `max(a.ndim(), b.ndim())` axes. Along every axis the result length is the
+/// product of the two padded lengths, and the chunk of the result that
+/// corresponds to an element `x` of `a` holds `f(x, y)` for every element `y`
+/// of `b`.
+///
+/// # Errors
+///
+/// Returns `ErrorKind::Overflow` if the product of the two lengths along some
+/// axis overflows.
+///
+/// # Panics
+///
+/// Panics if `into_shape` rejects either input view, because its result is
+/// unwrapped.
+#[allow(dead_code)]
+fn general_kron<Sa, Sb, D, F, T>(
+    a: &ArrayBase<Sa, D>,
+    b: &ArrayBase<Sb, D>,
+    mut f: F,
+) -> Result<Array<T, D>, ErrorKind>
+where
+    T: Copy,
+    Sa: Data<Elem = T>,
+    Sb: Data<Elem = T>,
+    D: Dimension,
+    F: FnMut(T, T) -> T,
+{
+    let res_ndim = std::cmp::max(a.ndim(), b.ndim());
+
+    // Pad each input shape with trailing 1s so that both have `res_ndim` axes.
+    let a_shape_completed = {
+        let a_shape_iter = a.shape().iter().cloned().chain(std::iter::repeat(1));
+        let mut a_shape_completed = D::zeros(res_ndim);
+        for (a_shape_completed_elem, a_shape_completed_value) in a_shape_completed
+            .as_array_view_mut()
+            .iter_mut()
+            .zip(a_shape_iter)
+        {
+            *a_shape_completed_elem = a_shape_completed_value;
+        }
+        a_shape_completed
+    };
+
+    let b_shape_completed = {
+        let b_shape_iter = b.shape().iter().cloned().chain(std::iter::repeat(1));
+        let mut b_shape_completed = D::zeros(res_ndim);
+        for (b_shape_completed_elem, b_shape_completed_value) in b_shape_completed
+            .as_array_view_mut()
+            .iter_mut()
+            .zip(b_shape_iter)
+        {
+            *b_shape_completed_elem = b_shape_completed_value;
+        }
+        b_shape_completed
+    };
+
+    // Create result shape, checking that the multiplication doesn't overflow to guarantee safety below
+    let res_dim = {
+        let mut res_dim: D = D::zeros(res_ndim);
+        for ((res_dim_elem, &a_shape_value), &b_shape_value) in res_dim
+            .as_array_view_mut()
+            .iter_mut()
+            .zip(a_shape_completed.as_array_view())
+            .zip(b_shape_completed.as_array_view())
+        {
+            match a_shape_value.checked_mul(b_shape_value) {
+                Some(n) => *res_dim_elem = n,
+                None => return Err(ErrorKind::Overflow),
+            }
+        }
+        res_dim
+    };
+
+    // Reshape input arrays to compatible shapes
+    let a_reshape = a.view().into_shape(a_shape_completed).unwrap();
+    let b_reshape = b.view().into_shape(b_shape_completed.clone()).unwrap();
+
+    // Create and fill the result array
+    let mut res: Array<MaybeUninit<T>, D> = ArrayBase::maybe_uninit(res_dim);
+    // NOTE(review): confirm how `exact_chunks_mut` behaves when `b` has a
+    // zero-length axis (the chunk shape would then contain a 0).
+    let res_chunks = res.exact_chunks_mut(b_shape_completed);
+    Zip::from(res_chunks)
+        .and(a_reshape)
+        .apply(|res_chunk, &a_elem| {
+            Zip::from(&b_reshape)
+                .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem)))
+        });
+    // SAFETY: each axis of `res` is the product of the matching axes of
+    // `a_reshape` and `b_reshape`, so the exact chunks of shape
+    // `b_shape_completed` cover all of `res`, and every chunk was written above.
+    let res = unsafe { res.assume_init() };
+    Ok(res)
+}
+
+/// Computes the Kronecker-style product of `a` and `b`, multiplying element pairs.
+///
+/// Delegates to `general_kron` with `std::ops::Mul::mul` as the combining
+/// function, so the same shape rules, errors and panics apply.
+#[allow(dead_code, clippy::type_repetition_in_bounds)]
+fn kron<Sa, Sb, D, T>(a: &ArrayBase<Sa, D>, b: &ArrayBase<Sb, D>) -> Result<Array<T, D>, ErrorKind>
+where
+    T: Copy,
+    Sa: Data<Elem = T>,
+    Sb: Data<Elem = T>,
+    D: Dimension,
+    T: crate::ScalarOperand + std::ops::Mul<Output = T>,
+{
+    general_kron(a, b, std::ops::Mul::mul)
+}
+
+#[cfg(test)]
+mod kron_test {
+    use super::*;
+
+    #[test]
+    fn same_dim() {
+        let a = array![
+            [[1, 2, 3], [4, 5, 6]],
+            [[17, 42, 69], [0, -1, 1]],
+            [[1337, 1, 0], [-1337, -1, 0]]
+        ];
+        let b = array![
+            [[55, 66, 77], [88, 99, 1010]],
+            [[42, 42, 0], [1, -3, 10]],
+            [[110, 0, 7], [523, 21, -12]]
+        ];
+        let res = kron(&a.view(), &b.view()).unwrap();
+        for a0 in 0..a.len_of(Axis(0)) {
+            for a1 in 0..a.len_of(Axis(1)) {
+                for a2 in 0..a.len_of(Axis(2)) {
+                    for b0 in 0..b.len_of(Axis(0)) {
+                        for b1 in 0..b.len_of(Axis(1)) {
+                            for b2 in 0..b.len_of(Axis(2)) {
+                                assert_eq!(
+                                    res[[
+                                        b.shape()[0] * a0 + b0,
+                                        b.shape()[1] * a1 + b1,
+                                        b.shape()[2] * a2 + b2
+                                    ]],
+                                    a[[a0, a1, a2]] * b[[b0, b1, b2]]
+                                )
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn different_dim() {
+        let a = array![
+            [1, 2, 3, 4, 5, 6],
+            [17, 42, 69, 0, -1, 1],
+            [1337, 1, 0, -1337, -1, 0]
+        ];
+        let b = array![
+            [[55, 66, 77], [88, 99, 1010]],
+            [[42, 42, 0], [1, -3, 10]],
+            [[110, 0, 7], [523, 21, -12]]
+        ];
+        let res = kron(&a.view().into_dyn(), &b.view().into_dyn()).unwrap();
+        for a0 in 0..a.len_of(Axis(0)) {
+            for a1 in 0..a.len_of(Axis(1)) {
+                for b0 in 0..b.len_of(Axis(0)) {
+                    for b1 in 0..b.len_of(Axis(1)) {
+                        for b2 in 0..b.len_of(Axis(2)) {
+                            assert_eq!(
+                                res[[b.shape()[0] * a0 + b0, b.shape()[1] * a1 + b1, b2]],
+                                a[[a0, a1]] * b[[b0, b1, b2]]
+                            )
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn different_dim2() {
+        let a = array![
+            [1, 2, 3, 4, 5, 6],
+            [17, 42, 69, 0, -1, 1],
+            [1337, 1, 0, -1337, -1, 0]
+        ];
+        let b = array![
+            [[55, 66, 77], [88, 99, 1010]],
+            [[42, 42, 0], [1, -3, 10]],
+            [[110, 0, 7], [523, 21, -12]]
+        ];
+        let res = kron(&b.view().into_dyn(), &a.view().into_dyn()).unwrap();
+        for a0 in 0..a.len_of(Axis(0)) {
+            for a1 in 0..a.len_of(Axis(1)) {
+                for b0 in 0..b.len_of(Axis(0)) {
+                    for b1 in 0..b.len_of(Axis(1)) {
+                        for b2 in 0..b.len_of(Axis(2)) {
+                            assert_eq!(
+                                res[[a.shape()[0] * b0 + a0, a.shape()[1] * b1 + a1, b2]],
+                                a[[a0, a1]] * b[[b0, b1, b2]]
+                            )
+                        }
+                    }
+                }
+            }
+        }
+    }
+}