Skip to content

Commit e51a69d

Browse files
committed
address comments
1 parent 9ad7238 commit e51a69d

File tree

2 files changed

+19
-33
lines changed

2 files changed

+19
-33
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ This release achieves 100% compliance with Python Array API specification (revis
2929
* Removed `einsum_call` keyword from `dpnp.einsum_path` signature [#2421](https://github.com/IntelPython/dpnp/pull/2421)
3030
* Changed `"max dimensions"` to `None` in array API capabilities [#2432](https://github.com/IntelPython/dpnp/pull/2432)
3131
* Updated kernel header `i0.hpp` to expose `cyl_bessel_i0` function depending on build target [#2440](https://github.com/IntelPython/dpnp/pull/2440)
32-
* Updated FFT module to make input array Hermitian before calling complex-to-real FFT [#2444](https://github.com/IntelPython/dpnp/pull/2444)
32+
* Updated FFT module to ensure an input array is Hermitian before calling complex-to-real FFT [#2444](https://github.com/IntelPython/dpnp/pull/2444)
3333

3434
### Fixed
3535

dpnp/fft/dpnp_utils_fft.py

+18-32
Original file line numberDiff line numberDiff line change
@@ -285,30 +285,12 @@ def _copy_array(x, complex_input):
285285
dtype = map_dtype_to_device(dpnp.float64, x.sycl_device)
286286

287287
if copy_flag:
288-
x = _copy_kernel(x, dtype)
288+
x = x.astype(dtype, copy=True)
289289

290290
# if copying is done, FFT can be in-place (copy_flag = in_place flag)
291291
return x, copy_flag
292292

293293

294-
def _copy_kernel(x, dtype):
295-
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
296-
297-
exec_q = x.sycl_queue
298-
_manager = dpu.SequentialOrderManager[exec_q]
299-
dep_evs = _manager.submitted_events
300-
301-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
302-
src=dpnp.get_usm_ndarray(x),
303-
dst=x_copy.get_array(),
304-
sycl_queue=exec_q,
305-
depends=dep_evs,
306-
)
307-
_manager.add_event_pair(ht_copy_ev, copy_ev)
308-
309-
return x_copy
310-
311-
312294
def _extract_axes_chunk(a, s, chunk_size=3):
313295
"""
314296
Classify the first input into a list of lists with each list containing
@@ -438,9 +420,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
438420
return result
439421

440422

441-
def _make_array_hermitian(a, n, copy_needed):
423+
def _make_array_hermitian(a, n, axis, copy_needed):
442424
"""
443-
For `dpnp.fft.irfft`, the input array should be Hermitian. If it is not,
425+
For complex-to-real FFT, the input array should be Hermitian. If it is not,
444426
the behavior is undefined. This function makes necessary changes to make
445427
sure the given array is Hermitian.
446428
@@ -449,6 +431,7 @@ def _make_array_hermitian(a, n, copy_needed):
449431
`_truncate_or_pad`, so the array has enough length.
450432
"""
451433

434+
a = dpnp.moveaxis(a, axis, 0)
452435
length_is_even = n % 2 == 0
453436
hermitian = dpnp.all(a[0].imag == 0)
454437
assert n is not None
@@ -463,14 +446,14 @@ def _make_array_hermitian(a, n, copy_needed):
463446

464447
if not hermitian:
465448
if copy_needed:
466-
a = _copy_kernel(a, a.dtype)
449+
a = a.astype(a.dtype, copy=True)
467450

468451
a[0].imag = 0
469452
if length_is_even:
470453
f_ny = n // 2
471454
a[f_ny].imag = 0
472455

473-
return a
456+
return dpnp.moveaxis(a, 0, axis)
474457

475458

476459
def _scale_result(res, a_shape, norm, forward, index):
@@ -634,11 +617,12 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
634617

635618
if c2r:
636619
# input array should be Hermitian for c2r FFT
637-
a = dpnp.moveaxis(a, axis, 0)
638620
a = _make_array_hermitian(
639-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
621+
a,
622+
n=a.shape[0],
623+
axis=axis,
624+
copy_needed=dpnp.are_same_logical_tensors(a, a_orig),
640625
)
641-
a = dpnp.moveaxis(a, 0, axis)
642626

643627
return _fft(
644628
a,
@@ -687,11 +671,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
687671
if len_axes == 1:
688672
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
689673
if c2r:
690-
a = dpnp.moveaxis(a, axes[-1], 0)
691674
a = _make_array_hermitian(
692-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
675+
a,
676+
n=a.shape[0],
677+
axis=axes[-1],
678+
copy_needed=dpnp.are_same_logical_tensors(a, a_orig),
693679
)
694-
a = dpnp.moveaxis(a, 0, axes[-1])
695680
return _fft(
696681
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
697682
)
@@ -743,11 +728,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
743728
)
744729
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
745730
if c2r:
746-
a = dpnp.moveaxis(a, axes[-1], 0)
747731
a = _make_array_hermitian(
748-
a, a.shape[0], dpnp.are_same_logical_tensors(a, a_orig)
732+
a,
733+
n=a.shape[0],
734+
axis=axes[-1],
735+
copy_needed=dpnp.are_same_logical_tensors(a, a_orig),
749736
)
750-
a = dpnp.moveaxis(a, 0, axes[-1])
751737
return _fft(
752738
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
753739
)

0 commit comments

Comments
 (0)