From 8c5738de3199e7a8d03baa4ee1941459f57ddfd4 Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Tue, 20 Aug 2019 16:44:39 +0200 Subject: [PATCH 1/9] First try for various outter product implementations --- src/linalg/impl_linalg.rs | 122 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index fd8d77d85..668efa377 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -823,3 +823,125 @@ mod blas_tests { assert!(blas_column_major_2d::(&m)); } } + +fn general_outer_to_dyn( + a: &ArrayBase, + b: &ArrayBase, + f: F, +) -> ArrayD +where + T: Copy, + Sa: Data, + Sb: Data, + I: Dimension, + F: Fn(ArrayViewMut, T, &ArrayBase) -> (), +{ + //Iterators on the shapes, compelted by 1s + let a_shape_iter = a.shape().iter().chain([1].iter().cycle()); + let b_shape_iter = b.shape().iter().chain([1].iter().cycle()); + + let res_ndim = std::cmp::max(a.ndim(), b.ndim()); + let res_dim: Vec = a_shape_iter + .zip(b_shape_iter) + .take(res_ndim) + .map(|(x, y)| x * y) + .collect(); + + unsafe { + let mut res: ArrayD = ArrayBase::uninitialized(res_dim); + let res_chunks = res.exact_chunks_mut(b.shape()); + //azip!(mut r_c (res_chunks), a in {f(r_c, a, b)}); + Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b)); + res + } +} + +fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array +where + T: Copy, + Sa: Data, + Sb: Data, + I: Dimension, + T: crate::ScalarOperand, + for<'a> &'a ArrayBase: std::ops::Mul>, +{ + general_outer_to_dyn(a, b, |mut res, x, a| res.assign(&(a * x))) +} + +fn general_outer_same_size( + a: &ArrayBase, + b: &ArrayBase, + f: F, +) -> Array +where + T: Copy, + Sa: Data, + Sb: Data, + I: Dimension, + F: Fn(ArrayViewMut, T, &ArrayBase) -> (), +{ + let mut res_dim = a.raw_dim().clone(); + let mut res_dim_view = res_dim.as_array_view_mut(); + res_dim_view *= &b.raw_dim().as_array_view(); + + unsafe { + let mut res: Array = ArrayBase::uninitialized(res_dim); + let res_chunks = res.exact_chunks_mut(b.raw_dim()); + Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b)); + res + } +} + +fn kron_same_size(a: &ArrayBase, b: &ArrayBase) -> Array +where + T: Copy, + Sa: Data, + Sb: Data, + I: Dimension, + T: crate::ScalarOperand, + for<'a> &'a ArrayBase: std::ops::Mul>, +{ + general_outer_same_size(a, b, |mut res, x, a| res.assign(&(a * x))) +} + +#[cfg(test)] +mod kron_test { + use super::*; + + #[test] + fn test_same_size() { + let a = array![ + [[1, 2, 3], [4, 5, 6]], + [[17, 42, 69], [0, -1, 1]], + [[1337, 1, 0], [-1337, -1, 0]] + ]; + let b = array![ + [[55, 66, 77], [88, 99, 1010]], + [[42, 42, 0], [1, -3, 10]], + [[110, 0, 7], [523, 21, -12]] + ]; + let res1 = kron_same_size(&a, &b); + let res2 = kron_to_dyn(&a.clone().into_dyn(), &b); + assert_eq!(res1.clone().into_dyn(), res2); + for a0 in 0..a.len_of(Axis(0)) { + for a1 in 0..a.len_of(Axis(1)) { + for a2 in 0..a.len_of(Axis(2)) { + for b0 in 0..b.len_of(Axis(0)) { + for b1 in 0..b.len_of(Axis(1)) { + for b2 in 0..b.len_of(Axis(2)) { + assert_eq!( + res2[[ + b.shape()[0] * a0 + b0, + b.shape()[1] * a1 + b1, + b.shape()[2] * a2 + b2 + ]], + a[[a0, a1, a2]] * b[[b0, b1, b2]] + ) + } + } + } + } + } + } + } +} From 71129cb1ca8fb96dc911c1b6f2743d7ab5c9722d Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Tue, 20 Aug 2019 19:26:59 +0200 Subject: [PATCH 2/9] Added allow dead code while WIP to pass CI. --- src/linalg/impl_linalg.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 668efa377..aadc8fa0c 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -824,6 +824,7 @@ mod blas_tests { } } +#[allow(dead_code)] fn general_outer_to_dyn( a: &ArrayBase, b: &ArrayBase, @@ -856,6 +857,7 @@ where } } +#[allow(dead_code)] fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array where T: Copy, @@ -868,6 +870,7 @@ where general_outer_to_dyn(a, b, |mut res, x, a| res.assign(&(a * x))) } +#[allow(dead_code)] fn general_outer_same_size( a: &ArrayBase, b: &ArrayBase, @@ -892,6 +895,7 @@ where } } +#[allow(dead_code)] fn kron_same_size(a: &ArrayBase, b: &ArrayBase) -> Array where T: Copy, From 26b10875b11d924de725fa6f938ea09068cd8a27 Mon Sep 17 00:00:00 2001 From: Arthur Carcano Date: Tue, 20 Aug 2019 23:55:39 +0200 Subject: [PATCH 3/9] Allow beta specific clippy:type_repetition_in_bounds --- src/linalg/impl_linalg.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index aadc8fa0c..b66b37187 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -857,7 +857,7 @@ where } } -#[allow(dead_code)] +#[allow(dead_code, clippy::type_repetition_in_bounds)] fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array where T: Copy, @@ -895,7 +895,7 @@ where } } -#[allow(dead_code)] +#[allow(dead_code, clippy::type_repetition_in_bounds)] fn kron_same_size(a: &ArrayBase, b: &ArrayBase) -> Array where T: Copy, From 3d9933e489a38ed44bfe4c89164e5a9aa9ecffad Mon Sep 17 00:00:00 2001 From: Arthur Carcano Date: Tue, 3 Mar 2020 10:05:14 +0100 Subject: [PATCH 4/9] One allocation less. --- src/linalg/impl_linalg.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index b66b37187..5a943ddf5 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -864,10 +864,13 @@ where Sa: Data, Sb: Data, I: Dimension, - T: crate::ScalarOperand, + T: crate::ScalarOperand + std::ops::MulAssign, for<'a> &'a ArrayBase: std::ops::Mul>, { - general_outer_to_dyn(a, b, |mut res, x, a| res.assign(&(a * x))) + general_outer_to_dyn(a, b, |mut res, x, a| { + res.assign(a); + res *= x + }) } #[allow(dead_code)] @@ -902,10 +905,13 @@ where Sa: Data, Sb: Data, I: Dimension, - T: crate::ScalarOperand, + T: crate::ScalarOperand + std::ops::MulAssign, for<'a> &'a ArrayBase: std::ops::Mul>, { - general_outer_same_size(a, b, |mut res, x, a| res.assign(&(a * x))) + general_outer_same_size(a, b, |mut res, x, a| { + res.assign(&a); + res *= x + }) } #[cfg(test)] From 566909ee635ec8c014a5a8cb1aaede15a97eba67 Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Tue, 3 Mar 2020 18:36:15 +0100 Subject: [PATCH 5/9] Tydied-up code a bit --- src/linalg/impl_linalg.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 5a943ddf5..bd8862fe2 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -851,8 +851,7 @@ where unsafe { let mut res: ArrayD = ArrayBase::uninitialized(res_dim); let res_chunks = res.exact_chunks_mut(b.shape()); - //azip!(mut r_c (res_chunks), a in {f(r_c, a, b)}); - Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b)); + Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| f(res_chunk, a_elem, b)); res } } @@ -886,7 +885,7 @@ where I: Dimension, F: Fn(ArrayViewMut, T, &ArrayBase) -> (), { - let mut res_dim = a.raw_dim().clone(); + let mut res_dim = a.raw_dim(); let mut res_dim_view = res_dim.as_array_view_mut(); res_dim_view *= &b.raw_dim().as_array_view(); From 6dc1a673b9dc594a353cfc74dd182ea2cc4bc455 Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Tue, 3 Mar 2020 19:38:54 +0100 Subject: [PATCH 6/9] New API, more clear and more efficient. --- src/linalg/impl_linalg.rs | 42 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index bd8862fe2..af8751f59 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -825,17 +825,16 @@ mod blas_tests { } #[allow(dead_code)] -fn general_outer_to_dyn( +fn general_outer_to_dyn( a: &ArrayBase, - b: &ArrayBase, + b: &ArrayBase, f: F, ) -> ArrayD where T: Copy, Sa: Data, Sb: Data, - I: Dimension, - F: Fn(ArrayViewMut, T, &ArrayBase) -> (), + F: Fn(T, T) -> T, { //Iterators on the shapes, compelted by 1s let a_shape_iter = a.shape().iter().chain([1].iter().cycle()); @@ -851,25 +850,24 @@ where unsafe { let mut res: ArrayD = ArrayBase::uninitialized(res_dim); let res_chunks = res.exact_chunks_mut(b.shape()); - Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| f(res_chunk, a_elem, b)); + Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { + Zip::from(res_chunk) + .and(b) + .apply(|res_elem, &b_elem| *res_elem = f(a_elem, b_elem)) + }); res } } #[allow(dead_code, clippy::type_repetition_in_bounds)] -fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array +fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array where T: Copy, Sa: Data, Sb: Data, - I: Dimension, - T: crate::ScalarOperand + std::ops::MulAssign, - for<'a> &'a ArrayBase: std::ops::Mul>, + T: crate::ScalarOperand + std::ops::Mul, { - general_outer_to_dyn(a, b, |mut res, x, a| { - res.assign(a); - res *= x - }) + general_outer_to_dyn(a, b, std::ops::Mul::mul) } #[allow(dead_code)] @@ -883,7 +881,7 @@ where Sa: Data, Sb: Data, I: Dimension, - F: Fn(ArrayViewMut, T, &ArrayBase) -> (), + F: Fn(T, T) -> T, { let mut res_dim = a.raw_dim(); let mut res_dim_view = res_dim.as_array_view_mut(); @@ -892,7 +890,11 @@ where unsafe { let mut res: Array = ArrayBase::uninitialized(res_dim); let res_chunks = res.exact_chunks_mut(b.raw_dim()); - Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b)); + Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { + Zip::from(res_chunk) + .and(b) + .apply(|r_elem, &b_elem| *r_elem = f(a_elem, b_elem)) + }); res } } @@ -904,13 +906,9 @@ where Sa: Data, Sb: Data, I: Dimension, - T: crate::ScalarOperand + std::ops::MulAssign, - for<'a> &'a ArrayBase: std::ops::Mul>, + T: crate::ScalarOperand + std::ops::Mul, { - general_outer_same_size(a, b, |mut res, x, a| { - res.assign(&a); - res *= x - }) + general_outer_same_size(a, b, std::ops::Mul::mul) } #[cfg(test)] @@ -930,7 +928,7 @@ mod kron_test { [[110, 0, 7], [523, 21, -12]] ]; let res1 = kron_same_size(&a, &b); - let res2 = kron_to_dyn(&a.clone().into_dyn(), &b); + let res2 = kron_to_dyn(&a.clone().into_dyn(), &b.clone().into_dyn()); assert_eq!(res1.clone().into_dyn(), res2); for a0 in 0..a.len_of(Axis(0)) { for a1 in 0..a.len_of(Axis(1)) { From 551deeb8efcfde0e8120c29576e19f0262e1dbec Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Mon, 20 Apr 2020 18:39:27 +0200 Subject: [PATCH 7/9] Changed to FnMut + MaybeUninit --- src/linalg/impl_linalg.rs | 45 ++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index af8751f59..cc28ce1f6 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -12,6 +12,7 @@ use crate::numeric_util; use crate::{LinalgScalar, Zip}; use std::any::TypeId; +use std::mem::MaybeUninit; #[cfg(feature = "blas")] use std::cmp; @@ -828,13 +829,13 @@ mod blas_tests { fn general_outer_to_dyn( a: &ArrayBase, b: &ArrayBase, - f: F, + mut f: F, ) -> ArrayD where T: Copy, Sa: Data, Sb: Data, - F: Fn(T, T) -> T, + F: FnMut(T, T) -> T, { //Iterators on the shapes, compelted by 1s let a_shape_iter = a.shape().iter().chain([1].iter().cycle()); @@ -847,16 +848,14 @@ where .map(|(x, y)| x * y) .collect(); - unsafe { - let mut res: ArrayD = ArrayBase::uninitialized(res_dim); - let res_chunks = res.exact_chunks_mut(b.shape()); - Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(res_chunk) - .and(b) - .apply(|res_elem, &b_elem| *res_elem = f(a_elem, b_elem)) - }); - res - } + let mut res: ArrayD> = ArrayBase::maybe_uninit(res_dim); + let res_chunks = res.exact_chunks_mut(b.shape()); + Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { + Zip::from(res_chunk) + .and(b) + .apply(|res_elem, &b_elem| *res_elem = MaybeUninit::new(f(a_elem, b_elem))) + }); + unsafe { res.assume_init() } } #[allow(dead_code, clippy::type_repetition_in_bounds)] @@ -874,29 +873,27 @@ where fn general_outer_same_size( a: &ArrayBase, b: &ArrayBase, - f: F, + mut f: F, ) -> Array where T: Copy, Sa: Data, Sb: Data, I: Dimension, - F: Fn(T, T) -> T, + F: FnMut(T, T) -> T, { let mut res_dim = a.raw_dim(); let mut res_dim_view = res_dim.as_array_view_mut(); res_dim_view *= &b.raw_dim().as_array_view(); - unsafe { - let mut res: Array = ArrayBase::uninitialized(res_dim); - let res_chunks = res.exact_chunks_mut(b.raw_dim()); - Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(res_chunk) - .and(b) - .apply(|r_elem, &b_elem| *r_elem = f(a_elem, b_elem)) - }); - res - } + let mut res: Array, I> = ArrayBase::maybe_uninit(res_dim); + let res_chunks = res.exact_chunks_mut(b.raw_dim()); + Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { + Zip::from(res_chunk) + .and(b) + .apply(|r_elem, &b_elem| *r_elem = MaybeUninit::new(f(a_elem, b_elem))) + }); + unsafe { res.assume_init() } } #[allow(dead_code, clippy::type_repetition_in_bounds)] From cc24d1ec812ba878a4afbf9a04681692b8ea545e Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Mon, 20 Apr 2020 18:46:23 +0200 Subject: [PATCH 8/9] Use apply_assign_into --- src/linalg/impl_linalg.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index cc28ce1f6..764980ba4 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -851,9 +851,8 @@ where let mut res: ArrayD> = ArrayBase::maybe_uninit(res_dim); let res_chunks = res.exact_chunks_mut(b.shape()); Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(res_chunk) - .and(b) - .apply(|res_elem, &b_elem| *res_elem = MaybeUninit::new(f(a_elem, b_elem))) + Zip::from(b) + .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem))) }); unsafe { res.assume_init() } } @@ -889,9 +888,8 @@ where let mut res: Array, I> = ArrayBase::maybe_uninit(res_dim); let res_chunks = res.exact_chunks_mut(b.raw_dim()); Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(res_chunk) - .and(b) - .apply(|r_elem, &b_elem| *r_elem = MaybeUninit::new(f(a_elem, b_elem))) + Zip::from(b) + .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem))) }); unsafe { res.assume_init() } } From d0098f45db89c30f1d288960a03dd46207fb7089 Mon Sep 17 00:00:00 2001 From: CARCANO Arthur Date: Tue, 21 Apr 2020 14:39:52 +0200 Subject: [PATCH 9/9] Change interface + fix bugs in different ndim case --- src/linalg/impl_linalg.rs | 194 +++++++++++++++++++++++++------------- 1 file changed, 129 insertions(+), 65 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 764980ba4..626c6d1ec 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -9,7 +9,7 @@ use crate::imp_prelude::*; use crate::numeric_util; -use crate::{LinalgScalar, Zip}; +use crate::{ErrorKind, LinalgScalar, Zip}; use std::any::TypeId; use std::mem::MaybeUninit; @@ -826,84 +826,92 @@ mod blas_tests { } #[allow(dead_code)] -fn general_outer_to_dyn( - a: &ArrayBase, - b: &ArrayBase, +fn general_kron( + a: &ArrayBase, + b: &ArrayBase, mut f: F, -) -> ArrayD +) -> Result, ErrorKind> where T: Copy, Sa: Data, Sb: Data, + D: Dimension, F: FnMut(T, T) -> T, { - //Iterators on the shapes, compelted by 1s - let a_shape_iter = a.shape().iter().chain([1].iter().cycle()); - let b_shape_iter = b.shape().iter().chain([1].iter().cycle()); - let res_ndim = std::cmp::max(a.ndim(), b.ndim()); - let res_dim: Vec = a_shape_iter - .zip(b_shape_iter) - .take(res_ndim) - .map(|(x, y)| x * y) - .collect(); - - let mut res: ArrayD> = ArrayBase::maybe_uninit(res_dim); - let res_chunks = res.exact_chunks_mut(b.shape()); - Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(b) - .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem))) - }); - unsafe { res.assume_init() } -} -#[allow(dead_code, clippy::type_repetition_in_bounds)] -fn kron_to_dyn(a: &ArrayBase, b: &ArrayBase) -> Array -where - T: Copy, - Sa: Data, - Sb: Data, - T: crate::ScalarOperand + std::ops::Mul, -{ - general_outer_to_dyn(a, b, std::ops::Mul::mul) -} + //Creates shapes completed by 1s to have res_ndim dims for each input array, + let a_shape_completed = { + let a_shape_iter = a.shape().iter().cloned().chain(std::iter::repeat(1)); + let mut a_shape_completed = D::zeros(res_ndim); + for (a_shape_completed_elem, a_shape_completed_value) in a_shape_completed + .as_array_view_mut() + .iter_mut() + .zip(a_shape_iter) + { + *a_shape_completed_elem = a_shape_completed_value; + } + a_shape_completed + }; -#[allow(dead_code)] -fn general_outer_same_size( - a: &ArrayBase, - b: &ArrayBase, - mut f: F, -) -> Array -where - T: Copy, - Sa: Data, - Sb: Data, - I: Dimension, - F: FnMut(T, T) -> T, -{ - let mut res_dim = a.raw_dim(); - let mut res_dim_view = res_dim.as_array_view_mut(); - res_dim_view *= &b.raw_dim().as_array_view(); - - let mut res: Array, I> = ArrayBase::maybe_uninit(res_dim); - let res_chunks = res.exact_chunks_mut(b.raw_dim()); - Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| { - Zip::from(b) - .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem))) - }); - unsafe { res.assume_init() } + let b_shape_completed = { + let b_shape_iter = b.shape().iter().cloned().chain(std::iter::repeat(1)); + let mut b_shape_completed = D::zeros(res_ndim); + for (b_shape_completed_elem, b_shape_completed_value) in b_shape_completed + .as_array_view_mut() + .iter_mut() + .zip(b_shape_iter) + { + *b_shape_completed_elem = b_shape_completed_value; + } + b_shape_completed + }; + + // Create result shape, checking that the multiplication doesn't overflow to guarantee safety below + let res_dim = { + let mut res_dim: D = D::zeros(res_ndim); + for ((res_dim_elem, &a_shape_value), &b_shape_value) in res_dim + .as_array_view_mut() + .iter_mut() + .zip(a_shape_completed.as_array_view()) + .zip(b_shape_completed.as_array_view()) + { + match a_shape_value.checked_mul(b_shape_value) { + Some(n) => *res_dim_elem = n, + None => return Err(ErrorKind::Overflow), + } + } + res_dim + }; + + // Reshape input arrays to compatible shapes + let a_reshape = a.view().into_shape(a_shape_completed).unwrap(); + let b_reshape = b.view().into_shape(b_shape_completed.clone()).unwrap(); + + //Create and fill the result array + let mut res: Array, D> = ArrayBase::maybe_uninit(res_dim); + let res_chunks = res.exact_chunks_mut(b_shape_completed); + Zip::from(res_chunks) + .and(a_reshape) + .apply(|res_chunk, &a_elem| { + Zip::from(&b_reshape) + .apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem))) + }); + // This is safe because the exact chunks covered exactly the res + let res = unsafe { res.assume_init() }; + Ok(res) } #[allow(dead_code, clippy::type_repetition_in_bounds)] -fn kron_same_size(a: &ArrayBase, b: &ArrayBase) -> Array +fn kron(a: &ArrayBase, b: &ArrayBase) -> Result, ErrorKind> where T: Copy, Sa: Data, Sb: Data, - I: Dimension, + D: Dimension, T: crate::ScalarOperand + std::ops::Mul, { - general_outer_same_size(a, b, std::ops::Mul::mul) + general_kron(a, b, std::ops::Mul::mul) } #[cfg(test)] @@ -911,7 +919,7 @@ mod kron_test { use super::*; #[test] - fn test_same_size() { + fn same_dim() { let a = array![ [[1, 2, 3], [4, 5, 6]], [[17, 42, 69], [0, -1, 1]], @@ -922,9 +930,7 @@ mod kron_test { [[42, 42, 0], [1, -3, 10]], [[110, 0, 7], [523, 21, -12]] ]; - let res1 = kron_same_size(&a, &b); - let res2 = kron_to_dyn(&a.clone().into_dyn(), &b.clone().into_dyn()); - assert_eq!(res1.clone().into_dyn(), res2); + let res = kron(&a.view(), &b.view()).unwrap(); for a0 in 0..a.len_of(Axis(0)) { for a1 in 0..a.len_of(Axis(1)) { for a2 in 0..a.len_of(Axis(2)) { @@ -932,7 +938,7 @@ mod kron_test { for b1 in 0..b.len_of(Axis(1)) { for b2 in 0..b.len_of(Axis(2)) { assert_eq!( - res2[[ + res[[ b.shape()[0] * a0 + b0, b.shape()[1] * a1 + b1, b.shape()[2] * a2 + b2 @@ -946,4 +952,62 @@ mod kron_test { } } } + + #[test] + fn different_dim() { + let a = array![ + [1, 2, 3, 4, 5, 6], + [17, 42, 69, 0, -1, 1], + [1337, 1, 0, -1337, -1, 0] + ]; + let b = array![ + [[55, 66, 77], [88, 99, 1010]], + [[42, 42, 0], [1, -3, 10]], + [[110, 0, 7], [523, 21, -12]] + ]; + let res = kron(&a.view().into_dyn(), &b.view().into_dyn()).unwrap(); + for a0 in 0..a.len_of(Axis(0)) { + for a1 in 0..a.len_of(Axis(1)) { + for b0 in 0..b.len_of(Axis(0)) { + for b1 in 0..b.len_of(Axis(1)) { + for b2 in 0..b.len_of(Axis(2)) { + assert_eq!( + res[[b.shape()[0] * a0 + b0, b.shape()[1] * a1 + b1, b2]], + a[[a0, a1]] * b[[b0, b1, b2]] + ) + } + } + } + } + } + } + + #[test] + fn different_dim2() { + let a = array![ + [1, 2, 3, 4, 5, 6], + [17, 42, 69, 0, -1, 1], + [1337, 1, 0, -1337, -1, 0] + ]; + let b = array![ + [[55, 66, 77], [88, 99, 1010]], + [[42, 42, 0], [1, -3, 10]], + [[110, 0, 7], [523, 21, -12]] + ]; + let res = kron(&b.view().into_dyn(), &a.view().into_dyn()).unwrap(); + for a0 in 0..a.len_of(Axis(0)) { + for a1 in 0..a.len_of(Axis(1)) { + for b0 in 0..b.len_of(Axis(0)) { + for b1 in 0..b.len_of(Axis(1)) { + for b2 in 0..b.len_of(Axis(2)) { + assert_eq!( + res[[a.shape()[0] * b0 + a0, a.shape()[1] * b1 + a1, b2]], + a[[a0, a1]] * b[[b0, b1, b2]] + ) + } + } + } + } + } + } }