Skip to content

Commit 1c685ef

Browse files
authored
Merge pull request #1105 from ethanhs/kron
Implement Kronecker product
2 parents 209d171 + 7d6fd72 commit 1c685ef

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

src/linalg/impl_linalg.rs

+34
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
1414
use crate::{LinalgScalar, Zip};
1515

1616
use std::any::TypeId;
17+
use std::mem::MaybeUninit;
1718
use alloc::vec::Vec;
1819

1920
#[cfg(feature = "blas")]
@@ -699,6 +700,39 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
699700
}
700701
}
701702

703+
704+
/// Kronecker product of 2D matrices.
705+
///
706+
/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R)
707+
/// matrix K formed by the block multiplication A_ij * B.
708+
pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
709+
where
710+
S1: Data<Elem = A>,
711+
S2: Data<Elem = A>,
712+
A: LinalgScalar,
713+
{
714+
let dimar = a.shape()[0];
715+
let dimac = a.shape()[1];
716+
let dimbr = b.shape()[0];
717+
let dimbc = b.shape()[1];
718+
let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
719+
dimar
720+
.checked_mul(dimbr)
721+
.expect("Dimensions of kronecker product output array overflows usize."),
722+
dimac
723+
.checked_mul(dimbc)
724+
.expect("Dimensions of kronecker product output array overflows usize."),
725+
));
726+
Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
727+
.and(a)
728+
.for_each(|out, &a| {
729+
Zip::from(out).and(b).for_each(|out, &b| {
730+
*out = MaybeUninit::new(a * b);
731+
})
732+
});
733+
unsafe { out.assume_init() }
734+
}
735+
702736
#[inline(always)]
703737
/// Return `true` if `A` and `B` are the same type
704738
fn same_type<A: 'static, B: 'static>() -> bool {

src/linalg/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
pub use self::impl_linalg::general_mat_mul;
1212
pub use self::impl_linalg::general_mat_vec_mul;
1313
pub use self::impl_linalg::Dot;
14+
pub use self::impl_linalg::kron;
1415

1516
mod impl_linalg;

tests/oper.rs

+63
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
)]
77
#![cfg(feature = "std")]
88
use ndarray::linalg::general_mat_mul;
9+
use ndarray::linalg::kron;
910
use ndarray::prelude::*;
1011
use ndarray::{rcarr1, rcarr2};
1112
use ndarray::{Data, LinalgScalar};
@@ -820,3 +821,65 @@ fn vec_mat_mul() {
820821
}
821822
}
822823
}
824+
825+
#[test]
826+
fn kron_square_f64() {
827+
let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
828+
let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]);
829+
830+
assert_eq!(
831+
kron(&a, &b),
832+
arr2(&[
833+
[0.0, 1.0, 0.0, 0.0],
834+
[1.0, 0.0, 0.0, 0.0],
835+
[0.0, 0.0, 0.0, 1.0],
836+
[0.0, 0.0, 1.0, 0.0]
837+
]),
838+
);
839+
840+
assert_eq!(
841+
kron(&b, &a),
842+
arr2(&[
843+
[0.0, 0.0, 1.0, 0.0],
844+
[0.0, 0.0, 0.0, 1.0],
845+
[1.0, 0.0, 0.0, 0.0],
846+
[0.0, 1.0, 0.0, 0.0]
847+
]),
848+
)
849+
}
850+
851+
#[test]
852+
fn kron_square_i64() {
853+
let a = arr2(&[[1, 0], [0, 1]]);
854+
let b = arr2(&[[0, 1], [1, 0]]);
855+
856+
assert_eq!(
857+
kron(&a, &b),
858+
arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]),
859+
);
860+
861+
assert_eq!(
862+
kron(&b, &a),
863+
arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
864+
)
865+
}
866+
867+
#[test]
868+
fn kron_i64() {
869+
let a = arr2(&[[1, 0]]);
870+
let b = arr2(&[[0, 1], [1, 0]]);
871+
let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]);
872+
assert_eq!(kron(&a, &b), r);
873+
874+
let a = arr2(&[[1, 0], [0, 0], [0, 1]]);
875+
let b = arr2(&[[0, 1], [1, 0]]);
876+
let r = arr2(&[
877+
[0, 1, 0, 0],
878+
[1, 0, 0, 0],
879+
[0, 0, 0, 0],
880+
[0, 0, 0, 0],
881+
[0, 0, 0, 1],
882+
[0, 0, 1, 0],
883+
]);
884+
assert_eq!(kron(&a, &b), r);
885+
}

0 commit comments

Comments
 (0)