@@ -825,17 +825,16 @@ mod blas_tests {
825
825
}
826
826
827
827
#[ allow( dead_code) ]
828
- fn general_outer_to_dyn < Sa , Sb , I , F , T > (
828
+ fn general_outer_to_dyn < Sa , Sb , F , T > (
829
829
a : & ArrayBase < Sa , IxDyn > ,
830
- b : & ArrayBase < Sb , I > ,
830
+ b : & ArrayBase < Sb , IxDyn > ,
831
831
f : F ,
832
832
) -> ArrayD < T >
833
833
where
834
834
T : Copy ,
835
835
Sa : Data < Elem = T > ,
836
836
Sb : Data < Elem = T > ,
837
- I : Dimension ,
838
- F : Fn ( ArrayViewMut < T , IxDyn > , T , & ArrayBase < Sb , I > ) -> ( ) ,
837
+ F : Fn ( T , T ) -> T ,
839
838
{
840
839
//Iterators on the shapes, compelted by 1s
841
840
let a_shape_iter = a. shape ( ) . iter ( ) . chain ( [ 1 ] . iter ( ) . cycle ( ) ) ;
@@ -851,25 +850,24 @@ where
851
850
unsafe {
852
851
let mut res: ArrayD < T > = ArrayBase :: uninitialized ( res_dim) ;
853
852
let res_chunks = res. exact_chunks_mut ( b. shape ( ) ) ;
854
- Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| f ( res_chunk, a_elem, b) ) ;
853
+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
854
+ Zip :: from ( res_chunk)
855
+ . and ( b)
856
+ . apply ( |res_elem, & b_elem| * res_elem = f ( a_elem, b_elem) )
857
+ } ) ;
855
858
res
856
859
}
857
860
}
858
861
859
862
#[ allow( dead_code, clippy:: type_repetition_in_bounds) ]
860
- fn kron_to_dyn < Sa , I , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , I > ) -> Array < T , IxDyn >
863
+ fn kron_to_dyn < Sa , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , IxDyn > ) -> Array < T , IxDyn >
861
864
where
862
865
T : Copy ,
863
866
Sa : Data < Elem = T > ,
864
867
Sb : Data < Elem = T > ,
865
- I : Dimension ,
866
- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
867
- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
868
+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
868
869
{
869
- general_outer_to_dyn ( a, b, |mut res, x, a| {
870
- res. assign ( a) ;
871
- res *= x
872
- } )
870
+ general_outer_to_dyn ( a, b, std:: ops:: Mul :: mul)
873
871
}
874
872
875
873
#[ allow( dead_code) ]
@@ -883,7 +881,7 @@ where
883
881
Sa : Data < Elem = T > ,
884
882
Sb : Data < Elem = T > ,
885
883
I : Dimension ,
886
- F : Fn ( ArrayViewMut < T , I > , T , & ArrayBase < Sb , I > ) -> ( ) ,
884
+ F : Fn ( T , T ) -> T ,
887
885
{
888
886
let mut res_dim = a. raw_dim ( ) ;
889
887
let mut res_dim_view = res_dim. as_array_view_mut ( ) ;
@@ -892,7 +890,11 @@ where
892
890
unsafe {
893
891
let mut res: Array < T , I > = ArrayBase :: uninitialized ( res_dim) ;
894
892
let res_chunks = res. exact_chunks_mut ( b. raw_dim ( ) ) ;
895
- Zip :: from ( res_chunks) . and ( a) . apply ( |r_c, & x| f ( r_c, x, b) ) ;
893
+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
894
+ Zip :: from ( res_chunk)
895
+ . and ( b)
896
+ . apply ( |r_elem, & b_elem| * r_elem = f ( a_elem, b_elem) )
897
+ } ) ;
896
898
res
897
899
}
898
900
}
@@ -904,13 +906,9 @@ where
904
906
Sa : Data < Elem = T > ,
905
907
Sb : Data < Elem = T > ,
906
908
I : Dimension ,
907
- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
908
- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
909
+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
909
910
{
910
- general_outer_same_size ( a, b, |mut res, x, a| {
911
- res. assign ( & a) ;
912
- res *= x
913
- } )
911
+ general_outer_same_size ( a, b, std:: ops:: Mul :: mul)
914
912
}
915
913
916
914
#[ cfg( test) ]
@@ -930,7 +928,7 @@ mod kron_test {
930
928
[ [ 110 , 0 , 7 ] , [ 523 , 21 , -12 ] ]
931
929
] ;
932
930
let res1 = kron_same_size ( & a, & b) ;
933
- let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b) ;
931
+ let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b. clone ( ) . into_dyn ( ) ) ;
934
932
assert_eq ! ( res1. clone( ) . into_dyn( ) , res2) ;
935
933
for a0 in 0 ..a. len_of ( Axis ( 0 ) ) {
936
934
for a1 in 0 ..a. len_of ( Axis ( 1 ) ) {
0 commit comments