Skip to content

Commit 6dc1a67

Browse files
committed
New API, more clear and more efficient.
1 parent 566909e commit 6dc1a67

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

src/linalg/impl_linalg.rs

+20-22
Original file line numberDiff line numberDiff line change
@@ -825,17 +825,16 @@ mod blas_tests {
825825
}
826826

827827
#[allow(dead_code)]
828-
fn general_outer_to_dyn<Sa, Sb, I, F, T>(
828+
fn general_outer_to_dyn<Sa, Sb, F, T>(
829829
a: &ArrayBase<Sa, IxDyn>,
830-
b: &ArrayBase<Sb, I>,
830+
b: &ArrayBase<Sb, IxDyn>,
831831
f: F,
832832
) -> ArrayD<T>
833833
where
834834
T: Copy,
835835
Sa: Data<Elem = T>,
836836
Sb: Data<Elem = T>,
837-
I: Dimension,
838-
F: Fn(ArrayViewMut<T, IxDyn>, T, &ArrayBase<Sb, I>) -> (),
837+
F: Fn(T, T) -> T,
839838
{
840839
//Iterators on the shapes, compelted by 1s
841840
let a_shape_iter = a.shape().iter().chain([1].iter().cycle());
@@ -851,25 +850,24 @@ where
851850
unsafe {
852851
let mut res: ArrayD<T> = ArrayBase::uninitialized(res_dim);
853852
let res_chunks = res.exact_chunks_mut(b.shape());
854-
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| f(res_chunk, a_elem, b));
853+
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
854+
Zip::from(res_chunk)
855+
.and(b)
856+
.apply(|res_elem, &b_elem| *res_elem = f(a_elem, b_elem))
857+
});
855858
res
856859
}
857860
}
858861

859862
#[allow(dead_code, clippy::type_repetition_in_bounds)]
860-
fn kron_to_dyn<Sa, I, Sb, T>(a: &ArrayBase<Sa, IxDyn>, b: &ArrayBase<Sb, I>) -> Array<T, IxDyn>
863+
fn kron_to_dyn<Sa, Sb, T>(a: &ArrayBase<Sa, IxDyn>, b: &ArrayBase<Sb, IxDyn>) -> Array<T, IxDyn>
861864
where
862865
T: Copy,
863866
Sa: Data<Elem = T>,
864867
Sb: Data<Elem = T>,
865-
I: Dimension,
866-
T: crate::ScalarOperand + std::ops::MulAssign,
867-
for<'a> &'a ArrayBase<Sb, I>: std::ops::Mul<T, Output = Array<T, I>>,
868+
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
868869
{
869-
general_outer_to_dyn(a, b, |mut res, x, a| {
870-
res.assign(a);
871-
res *= x
872-
})
870+
general_outer_to_dyn(a, b, std::ops::Mul::mul)
873871
}
874872

875873
#[allow(dead_code)]
@@ -883,7 +881,7 @@ where
883881
Sa: Data<Elem = T>,
884882
Sb: Data<Elem = T>,
885883
I: Dimension,
886-
F: Fn(ArrayViewMut<T, I>, T, &ArrayBase<Sb, I>) -> (),
884+
F: Fn(T, T) -> T,
887885
{
888886
let mut res_dim = a.raw_dim();
889887
let mut res_dim_view = res_dim.as_array_view_mut();
@@ -892,7 +890,11 @@ where
892890
unsafe {
893891
let mut res: Array<T, I> = ArrayBase::uninitialized(res_dim);
894892
let res_chunks = res.exact_chunks_mut(b.raw_dim());
895-
Zip::from(res_chunks).and(a).apply(|r_c, &x| f(r_c, x, b));
893+
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
894+
Zip::from(res_chunk)
895+
.and(b)
896+
.apply(|r_elem, &b_elem| *r_elem = f(a_elem, b_elem))
897+
});
896898
res
897899
}
898900
}
@@ -904,13 +906,9 @@ where
904906
Sa: Data<Elem = T>,
905907
Sb: Data<Elem = T>,
906908
I: Dimension,
907-
T: crate::ScalarOperand + std::ops::MulAssign,
908-
for<'a> &'a ArrayBase<Sb, I>: std::ops::Mul<T, Output = Array<T, I>>,
909+
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
909910
{
910-
general_outer_same_size(a, b, |mut res, x, a| {
911-
res.assign(&a);
912-
res *= x
913-
})
911+
general_outer_same_size(a, b, std::ops::Mul::mul)
914912
}
915913

916914
#[cfg(test)]
@@ -930,7 +928,7 @@ mod kron_test {
930928
[[110, 0, 7], [523, 21, -12]]
931929
];
932930
let res1 = kron_same_size(&a, &b);
933-
let res2 = kron_to_dyn(&a.clone().into_dyn(), &b);
931+
let res2 = kron_to_dyn(&a.clone().into_dyn(), &b.clone().into_dyn());
934932
assert_eq!(res1.clone().into_dyn(), res2);
935933
for a0 in 0..a.len_of(Axis(0)) {
936934
for a1 in 0..a.len_of(Axis(1)) {

0 commit comments

Comments
 (0)