Skip to content

Commit 209d171

Browse files
authored
Merge pull request #1107 from ethanhs/split_complex
Implement real/imag splitting of arrays
2 parents c5e3ba8 + 6937be5 commit 209d171

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed

src/impl_raw_views.rs

+85
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use num_complex::Complex;
12
use std::mem;
23
use std::ptr::NonNull;
34

@@ -149,6 +150,73 @@ where
149150
}
150151
}
151152

153+
impl<T, D> RawArrayView<Complex<T>, D>
154+
where
155+
D: Dimension,
156+
{
157+
/// Splits the view into views of the real and imaginary components of the
158+
/// elements.
159+
pub fn split_re_im(self) -> Complex<RawArrayView<T, D>> {
160+
// Check that the size and alignment of `Complex<T>` are as expected.
161+
// These assertions should always pass, for arbitrary `T`.
162+
assert_eq!(
163+
mem::size_of::<Complex<T>>(),
164+
mem::size_of::<T>().checked_mul(2).unwrap()
165+
);
166+
assert_eq!(mem::align_of::<Complex<T>>(), mem::align_of::<T>());
167+
168+
let dim = self.dim.clone();
169+
170+
// Double the strides. In the zero-sized element case and for axes of
171+
// length <= 1, we leave the strides as-is to avoid possible overflow.
172+
let mut strides = self.strides.clone();
173+
if mem::size_of::<T>() != 0 {
174+
for ax in 0..strides.ndim() {
175+
if dim[ax] > 1 {
176+
strides[ax] = (strides[ax] as isize * 2) as usize;
177+
}
178+
}
179+
}
180+
181+
let ptr_re: *mut T = self.ptr.as_ptr().cast();
182+
let ptr_im: *mut T = if self.is_empty() {
183+
// In the empty case, we can just reuse the existing pointer since
184+
// it won't be dereferenced anyway. It is not safe to offset by
185+
// one, since the allocation may be empty.
186+
ptr_re
187+
} else {
188+
// In the nonempty case, we can safely offset into the first
189+
// (complex) element.
190+
unsafe { ptr_re.add(1) }
191+
};
192+
193+
// `Complex` is `repr(C)` with only fields `re: T` and `im: T`. So, the
194+
// real components of the elements start at the same pointer, and the
195+
// imaginary components start at the pointer offset by one, with
196+
// exactly double the strides. The new, doubled strides still meet the
197+
// overflow constraints:
198+
//
199+
// - For the zero-sized element case, the strides are unchanged in
200+
// units of bytes and in units of the element type.
201+
//
202+
// - For the nonzero-sized element case:
203+
//
204+
// - In units of bytes, the strides are unchanged. The only exception
205+
// is axes of length <= 1, but those strides are irrelevant anyway.
206+
//
207+
// - Since `Complex<T>` for nonzero `T` is always at least 2 bytes,
208+
// and the original strides did not overflow in units of bytes, we
209+
// know that the new, doubled strides will not overflow in units of
210+
// `T`.
211+
unsafe {
212+
Complex {
213+
re: RawArrayView::new_(ptr_re, dim.clone(), strides.clone()),
214+
im: RawArrayView::new_(ptr_im, dim, strides),
215+
}
216+
}
217+
}
218+
}
219+
152220
impl<A, D> RawArrayViewMut<A, D>
153221
where
154222
D: Dimension,
@@ -300,3 +368,20 @@ where
300368
unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) }
301369
}
302370
}
371+
372+
impl<T, D> RawArrayViewMut<Complex<T>, D>
373+
where
374+
D: Dimension,
375+
{
376+
/// Splits the view into views of the real and imaginary components of the
377+
/// elements.
378+
pub fn split_re_im(self) -> Complex<RawArrayViewMut<T, D>> {
379+
let Complex { re, im } = self.into_raw_view().split_re_im();
380+
unsafe {
381+
Complex {
382+
re: RawArrayViewMut::new(re.ptr, re.dim, re.strides),
383+
im: RawArrayViewMut::new(im.ptr, im.dim, im.strides),
384+
}
385+
}
386+
}
387+
}

src/impl_views/splitting.rs

+70
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use crate::imp_prelude::*;
1010
use crate::slice::MultiSliceArg;
11+
use num_complex::Complex;
1112

1213
/// Methods for read-only array views.
1314
impl<'a, A, D> ArrayView<'a, A, D>
@@ -95,6 +96,37 @@ where
9596
}
9697
}
9798

99+
impl<'a, T, D> ArrayView<'a, Complex<T>, D>
100+
where
101+
D: Dimension,
102+
{
103+
/// Splits the view into views of the real and imaginary components of the
104+
/// elements.
105+
///
106+
/// ```
107+
/// use ndarray::prelude::*;
108+
/// use num_complex::{Complex, Complex64};
109+
///
110+
/// let arr = array![
111+
/// [Complex64::new(1., 2.), Complex64::new(3., 4.)],
112+
/// [Complex64::new(5., 6.), Complex64::new(7., 8.)],
113+
/// [Complex64::new(9., 10.), Complex64::new(11., 12.)],
114+
/// ];
115+
/// let Complex { re, im } = arr.view().split_re_im();
116+
/// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]);
117+
/// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]);
118+
/// ```
119+
pub fn split_re_im(self) -> Complex<ArrayView<'a, T, D>> {
120+
unsafe {
121+
let Complex { re, im } = self.into_raw_view().split_re_im();
122+
Complex {
123+
re: re.deref_into_view(),
124+
im: im.deref_into_view(),
125+
}
126+
}
127+
}
128+
}
129+
98130
/// Methods for read-write array views.
99131
impl<'a, A, D> ArrayViewMut<'a, A, D>
100132
where
@@ -135,3 +167,41 @@ where
135167
info.multi_slice_move(self)
136168
}
137169
}
170+
171+
impl<'a, T, D> ArrayViewMut<'a, Complex<T>, D>
172+
where
173+
D: Dimension,
174+
{
175+
/// Splits the view into views of the real and imaginary components of the
176+
/// elements.
177+
///
178+
/// ```
179+
/// use ndarray::prelude::*;
180+
/// use num_complex::{Complex, Complex64};
181+
///
182+
/// let mut arr = array![
183+
/// [Complex64::new(1., 2.), Complex64::new(3., 4.)],
184+
/// [Complex64::new(5., 6.), Complex64::new(7., 8.)],
185+
/// [Complex64::new(9., 10.), Complex64::new(11., 12.)],
186+
/// ];
187+
///
188+
/// let Complex { mut re, mut im } = arr.view_mut().split_re_im();
189+
/// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]);
190+
/// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]);
191+
///
192+
/// re[[0, 1]] = 13.;
193+
/// im[[2, 0]] = 14.;
194+
///
195+
/// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.));
196+
/// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.));
197+
/// ```
198+
pub fn split_re_im(self) -> Complex<ArrayViewMut<'a, T, D>> {
199+
unsafe {
200+
let Complex { re, im } = self.into_raw_view_mut().split_re_im();
201+
Complex {
202+
re: re.deref_into_view_mut(),
203+
im: im.deref_into_view_mut(),
204+
}
205+
}
206+
}
207+
}

tests/array.rs

+70
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
clippy::float_cmp
88
)]
99

10+
use approx::assert_relative_eq;
1011
use defmac::defmac;
1112
use itertools::{zip, Itertools};
1213
use ndarray::prelude::*;
1314
use ndarray::{arr3, rcarr2};
1415
use ndarray::indices;
1516
use ndarray::{Slice, SliceInfo, SliceInfoElem};
17+
use num_complex::Complex;
1618
use std::convert::TryFrom;
1719

1820
macro_rules! assert_panics {
@@ -2501,3 +2503,71 @@ fn test_remove_index_oob3() {
25012503
let mut a = array![[10], [4], [1]];
25022504
a.remove_index(Axis(2), 0);
25032505
}
2506+
2507+
#[test]
2508+
fn test_split_re_im_view() {
2509+
let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| {
2510+
Complex::<f32>::new(i as f32 * j as f32, k as f32)
2511+
});
2512+
let Complex { re, im } = a.view().split_re_im();
2513+
assert_relative_eq!(re.sum(), 90.);
2514+
assert_relative_eq!(im.sum(), 120.);
2515+
}
2516+
2517+
#[test]
2518+
fn test_split_re_im_view_roundtrip() {
2519+
let a_re = Array3::from_shape_fn((3,1,5), |(i, j, _k)| {
2520+
i * j
2521+
});
2522+
let a_im = Array3::from_shape_fn((3,1,5), |(_i, _j, k)| {
2523+
k
2524+
});
2525+
let a = Array3::from_shape_fn((3,1,5), |(i,j,k)| {
2526+
Complex::new(a_re[[i,j,k]], a_im[[i,j,k]])
2527+
});
2528+
let Complex { re, im } = a.view().split_re_im();
2529+
assert_eq!(a_re, re);
2530+
assert_eq!(a_im, im);
2531+
}
2532+
2533+
#[test]
2534+
fn test_split_re_im_view_mut() {
2535+
let eye_scalar = Array2::<u32>::eye(4);
2536+
let eye_complex = Array2::<Complex<u32>>::eye(4);
2537+
let mut a = Array2::<Complex<u32>>::zeros((4, 4));
2538+
let Complex { mut re, im } = a.view_mut().split_re_im();
2539+
re.assign(&eye_scalar);
2540+
assert_eq!(im.sum(), 0);
2541+
assert_eq!(a, eye_complex);
2542+
}
2543+
2544+
#[test]
2545+
fn test_split_re_im_zerod() {
2546+
let mut a = Array0::from_elem((), Complex::new(42, 32));
2547+
let Complex { re, im } = a.view().split_re_im();
2548+
assert_eq!(re.get(()), Some(&42));
2549+
assert_eq!(im.get(()), Some(&32));
2550+
let cmplx = a.view_mut().split_re_im();
2551+
cmplx.re.assign_to(cmplx.im);
2552+
assert_eq!(a.get(()).unwrap().im, 42);
2553+
}
2554+
2555+
#[test]
2556+
fn test_split_re_im_permuted() {
2557+
let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| {
2558+
Complex::new(i * k + j, k)
2559+
});
2560+
let permuted = a.view().permuted_axes([1,0,2]);
2561+
let Complex { re, im } = permuted.split_re_im();
2562+
assert_eq!(re.get((3,2,4)).unwrap(), &11);
2563+
assert_eq!(im.get((3,2,4)).unwrap(), &4);
2564+
}
2565+
2566+
#[test]
2567+
fn test_split_re_im_invert_axis() {
2568+
let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64));
2569+
a.invert_axis(Axis(1));
2570+
let cmplx = a.view().split_re_im();
2571+
assert_eq!(cmplx.re, a.mapv(|z| z.re));
2572+
assert_eq!(cmplx.im, a.mapv(|z| z.im));
2573+
}

0 commit comments

Comments
 (0)