diff --git a/benches/bench1.rs b/benches/bench1.rs index 35a1d6e7e..291a25e97 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -431,6 +431,22 @@ fn scalar_add_2(bench: &mut test::Bencher) { bench.iter(|| n + &a); } +#[bench] +fn scalar_add_strided_1(bench: &mut test::Bencher) { + let a = + Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| &a + n); +} + +#[bench] +fn scalar_add_strided_2(bench: &mut test::Bencher) { + let a = + Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| n + &a); +} + #[bench] fn scalar_sub_1(bench: &mut test::Bencher) { let a = Array::::zeros((64, 64)); diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 4804356e8..8999999f1 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -152,70 +152,56 @@ impl $trt for ArrayBase #[doc=$doc] /// between the reference `self` and the scalar `x`, /// and return the result as a new `Array`. -impl<'a, A, S, D, B> $trt for &'a ArrayBase - where A: Clone + $trt, +impl<'a, A, S, D, B, C> $trt for &'a ArrayBase + where A: Clone + $trt, S: Data, D: Dimension, B: ScalarOperand, { - type Output = Array; - fn $mth(self, x: B) -> Array { - self.to_owned().$mth(x) + type Output = Array; + fn $mth(self, x: B) -> Self::Output { + self.map(move |elt| elt.clone() $operator x.clone()) } } ); ); -// Pick the expression $a for commutative and $b for ordered binop -macro_rules! if_commutative { - (Commute { $a:expr } or { $b:expr }) => { - $a - }; - (Ordered { $a:expr } or { $b:expr }) => { - $b - }; -} - macro_rules! impl_scalar_lhs_op { - // $commutative flag. Reuse the self + scalar impl if we can. - // We can do this safely since these are the primitive numeric types - ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( -// these have no doc -- they are not visible in rustdoc -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result (based on `self`). -impl $trt> for $scalar - where S: DataOwned + DataMut, - D: Dimension, + ($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result (based on `self`). +impl $trt> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: DataOwned + DataMut, + D: Dimension, { type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase { - if_commutative!($commutative { - rhs.$mth(self) - } or {{ - let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { - *elt = self $operator *elt; - }); - rhs - }}) + fn $mth(self, mut rhs: ArrayBase) -> ArrayBase { + rhs.unordered_foreach_mut(move |elt| { + *elt = self.clone() $operator elt.clone(); + }); + rhs } } -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result as a new `Array`. -impl<'a, S, D> $trt<&'a ArrayBase> for $scalar - where S: Data, - D: Dimension, +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result as a new `Array`. +impl<'a, A, S, D, B> $trt<&'a ArrayBase> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: Data, + D: Dimension, { - type Output = Array<$scalar, D>; - fn $mth(self, rhs: &ArrayBase) -> Array<$scalar, D> { - if_commutative!($commutative { - rhs.$mth(self) - } or { - self.$mth(rhs.to_owned()) - }) + type Output = Array; + fn $mth(self, rhs: &ArrayBase) -> Array { + rhs.map(move |elt| self.clone() $operator elt.clone()) } } ); @@ -241,16 +227,16 @@ mod arithmetic_ops { macro_rules! all_scalar_ops { ($int_scalar:ty) => ( - impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift"); - impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift"); + impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition"); + impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!($int_scalar, /, Div, div, "division"); + impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift"); + impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift"); ); } all_scalar_ops!(i8); @@ -264,31 +250,31 @@ mod arithmetic_ops { all_scalar_ops!(i128); all_scalar_ops!(u128); - impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f32, +, Add, add, "addition"); + impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f32, /, Div, div, "division"); + impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f64, +, Add, add, "addition"); + impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f64, /, Div, div, "division"); + impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); impl Neg for ArrayBase where