@@ -285,30 +285,12 @@ def _copy_array(x, complex_input):
285
285
dtype = map_dtype_to_device (dpnp .float64 , x .sycl_device )
286
286
287
287
if copy_flag :
288
- x = _copy_kernel ( x , dtype )
288
+ x = x . astype ( dtype , copy = True )
289
289
290
290
# if copying is done, FFT can be in-place (copy_flag = in_place flag)
291
291
return x , copy_flag
292
292
293
293
294
- def _copy_kernel (x , dtype ):
295
- x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
296
-
297
- exec_q = x .sycl_queue
298
- _manager = dpu .SequentialOrderManager [exec_q ]
299
- dep_evs = _manager .submitted_events
300
-
301
- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
302
- src = dpnp .get_usm_ndarray (x ),
303
- dst = x_copy .get_array (),
304
- sycl_queue = exec_q ,
305
- depends = dep_evs ,
306
- )
307
- _manager .add_event_pair (ht_copy_ev , copy_ev )
308
-
309
- return x_copy
310
-
311
-
312
294
def _extract_axes_chunk (a , s , chunk_size = 3 ):
313
295
"""
314
296
Classify the first input into a list of lists with each list containing
@@ -438,9 +420,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
438
420
return result
439
421
440
422
441
- def _make_array_hermitian (a , n , copy_needed ):
423
+ def _make_array_hermitian (a , n , axis , copy_needed ):
442
424
"""
443
- For `dpnp.fft.irfft` , the input array should be Hermitian. If it is not,
425
+ For complex-to-real FFT , the input array should be Hermitian. If it is not,
444
426
the behavior is undefined. This function makes necessary changes to make
445
427
sure the given array is Hermitian.
446
428
@@ -449,6 +431,7 @@ def _make_array_hermitian(a, n, copy_needed):
449
431
`_truncate_or_pad`, so the array has enough length.
450
432
"""
451
433
434
+ a = dpnp .moveaxis (a , axis , 0 )
452
435
length_is_even = n % 2 == 0
453
436
hermitian = dpnp .all (a [0 ].imag == 0 )
454
437
assert n is not None
@@ -463,14 +446,14 @@ def _make_array_hermitian(a, n, copy_needed):
463
446
464
447
if not hermitian :
465
448
if copy_needed :
466
- a = _copy_kernel ( a , a .dtype )
449
+ a = a . astype ( a .dtype , copy = True )
467
450
468
451
a [0 ].imag = 0
469
452
if length_is_even :
470
453
f_ny = n // 2
471
454
a [f_ny ].imag = 0
472
455
473
- return a
456
+ return dpnp . moveaxis ( a , 0 , axis )
474
457
475
458
476
459
def _scale_result (res , a_shape , norm , forward , index ):
@@ -634,11 +617,12 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
634
617
635
618
if c2r :
636
619
# input array should be Hermitian for c2r FFT
637
- a = dpnp .moveaxis (a , axis , 0 )
638
620
a = _make_array_hermitian (
639
- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
621
+ a ,
622
+ n = a .shape [0 ],
623
+ axis = axis ,
624
+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
640
625
)
641
- a = dpnp .moveaxis (a , 0 , axis )
642
626
643
627
return _fft (
644
628
a ,
@@ -687,11 +671,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
687
671
if len_axes == 1 :
688
672
a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
689
673
if c2r :
690
- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
691
674
a = _make_array_hermitian (
692
- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
675
+ a ,
676
+ n = a .shape [0 ],
677
+ axis = axes [- 1 ],
678
+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
693
679
)
694
- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
695
680
return _fft (
696
681
a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
697
682
)
@@ -743,11 +728,12 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
743
728
)
744
729
a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
745
730
if c2r :
746
- a = dpnp .moveaxis (a , axes [- 1 ], 0 )
747
731
a = _make_array_hermitian (
748
- a , a .shape [0 ], dpnp .are_same_logical_tensors (a , a_orig )
732
+ a ,
733
+ n = a .shape [0 ],
734
+ axis = axes [- 1 ],
735
+ copy_needed = dpnp .are_same_logical_tensors (a , a_orig ),
749
736
)
750
- a = dpnp .moveaxis (a , 0 , axes [- 1 ])
751
737
return _fft (
752
738
a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
753
739
)
0 commit comments