|
6 | 6 | )]
|
7 | 7 | #![cfg(feature = "std")]
|
8 | 8 | use ndarray::linalg::general_mat_mul;
|
| 9 | +use ndarray::linalg::kron; |
9 | 10 | use ndarray::prelude::*;
|
10 | 11 | use ndarray::{rcarr1, rcarr2};
|
11 | 12 | use ndarray::{Data, LinalgScalar};
|
@@ -820,3 +821,65 @@ fn vec_mat_mul() {
|
820 | 821 | }
|
821 | 822 | }
|
822 | 823 | }
|
| 824 | + |
| 825 | +#[test] |
| 826 | +fn kron_square_f64() { |
| 827 | + let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]); |
| 828 | + let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]); |
| 829 | + |
| 830 | + assert_eq!( |
| 831 | + kron(&a, &b), |
| 832 | + arr2(&[ |
| 833 | + [0.0, 1.0, 0.0, 0.0], |
| 834 | + [1.0, 0.0, 0.0, 0.0], |
| 835 | + [0.0, 0.0, 0.0, 1.0], |
| 836 | + [0.0, 0.0, 1.0, 0.0] |
| 837 | + ]), |
| 838 | + ); |
| 839 | + |
| 840 | + assert_eq!( |
| 841 | + kron(&b, &a), |
| 842 | + arr2(&[ |
| 843 | + [0.0, 0.0, 1.0, 0.0], |
| 844 | + [0.0, 0.0, 0.0, 1.0], |
| 845 | + [1.0, 0.0, 0.0, 0.0], |
| 846 | + [0.0, 1.0, 0.0, 0.0] |
| 847 | + ]), |
| 848 | + ) |
| 849 | +} |
| 850 | + |
| 851 | +#[test] |
| 852 | +fn kron_square_i64() { |
| 853 | + let a = arr2(&[[1, 0], [0, 1]]); |
| 854 | + let b = arr2(&[[0, 1], [1, 0]]); |
| 855 | + |
| 856 | + assert_eq!( |
| 857 | + kron(&a, &b), |
| 858 | + arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]), |
| 859 | + ); |
| 860 | + |
| 861 | + assert_eq!( |
| 862 | + kron(&b, &a), |
| 863 | + arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]), |
| 864 | + ) |
| 865 | +} |
| 866 | + |
| 867 | +#[test] |
| 868 | +fn kron_i64() { |
| 869 | + let a = arr2(&[[1, 0]]); |
| 870 | + let b = arr2(&[[0, 1], [1, 0]]); |
| 871 | + let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]); |
| 872 | + assert_eq!(kron(&a, &b), r); |
| 873 | + |
| 874 | + let a = arr2(&[[1, 0], [0, 0], [0, 1]]); |
| 875 | + let b = arr2(&[[0, 1], [1, 0]]); |
| 876 | + let r = arr2(&[ |
| 877 | + [0, 1, 0, 0], |
| 878 | + [1, 0, 0, 0], |
| 879 | + [0, 0, 0, 0], |
| 880 | + [0, 0, 0, 0], |
| 881 | + [0, 0, 0, 1], |
| 882 | + [0, 0, 1, 0], |
| 883 | + ]); |
| 884 | + assert_eq!(kron(&a, &b), r); |
| 885 | +} |
0 commit comments