5
5
See also ..testing for public testing utilities.
6
6
"""
7
7
8
+ from __future__ import annotations
9
+
8
10
import math
9
11
from types import ModuleType
10
- from typing import cast
12
+ from typing import Any , cast
11
13
14
+ import numpy as np
12
15
import pytest
13
16
14
17
from ._utils ._compat import (
15
18
array_namespace ,
16
19
is_array_api_strict_namespace ,
17
20
is_cupy_namespace ,
18
21
is_dask_namespace ,
22
+ is_jax_namespace ,
23
+ is_numpy_namespace ,
19
24
is_pydata_sparse_namespace ,
20
25
is_torch_namespace ,
26
+ to_device ,
21
27
)
22
- from ._utils ._typing import Array
28
+ from ._utils ._typing import Array , Device
23
29
24
- __all__ = ["xp_assert_close" , "xp_assert_equal" ]
30
+ __all__ = ["as_numpy_array" , " xp_assert_close" , "xp_assert_equal" , "xp_assert_less " ]
25
31
26
32
27
33
def _check_ns_shape_dtype (
28
- actual : Array , desired : Array
34
+ actual : Array ,
35
+ desired : Array ,
36
+ check_dtype : bool ,
37
+ check_shape : bool ,
38
+ check_scalar : bool ,
29
39
) -> ModuleType : # numpydoc ignore=RT03
30
40
"""
31
41
Assert that namespace, shape and dtype of the two arrays match.
@@ -36,6 +46,11 @@ def _check_ns_shape_dtype(
36
46
The array produced by the tested function.
37
47
desired : Array
38
48
The expected array (typically hardcoded).
49
+ check_dtype, check_shape : bool, default: True
50
+ Whether to check agreement between actual and desired dtypes and shapes
51
+ check_scalar : bool, default: False
52
+ NumPy only: whether to check agreement between actual and desired types -
53
+ 0d array vs scalar.
39
54
40
55
Returns
41
56
-------
@@ -47,43 +62,67 @@ def _check_ns_shape_dtype(
47
62
msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
48
63
assert actual_xp == desired_xp , msg
49
64
50
- actual_shape = actual .shape
51
- desired_shape = desired .shape
52
- if is_dask_namespace (desired_xp ):
53
- # Dask uses nan instead of None for unknown shapes
54
- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
55
- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56
- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
57
- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58
-
59
- msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
60
- assert actual_shape == desired_shape , msg
61
-
62
- msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
63
- assert actual .dtype == desired .dtype , msg
65
+ if check_shape :
66
+ actual_shape = actual .shape
67
+ desired_shape = desired .shape
68
+ if is_dask_namespace (desired_xp ):
69
+ # Dask uses nan instead of None for unknown shapes
70
+ if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71
+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72
+ if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73
+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74
+
75
+ msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
76
+ assert actual_shape == desired_shape , msg
77
+
78
+ if check_dtype :
79
+ msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
80
+ assert actual .dtype == desired .dtype , msg
81
+
82
+ if is_numpy_namespace (actual_xp ) and check_scalar :
83
+ # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
84
+ _msg = (
85
+ "array-ness does not match:\n Actual: "
86
+ f"{ type (actual )} \n Desired: { type (desired )} "
87
+ )
88
+ assert np .isscalar (actual ) == np .isscalar (desired ), _msg
64
89
65
90
return desired_xp
66
91
67
92
68
- def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
93
+ def as_numpy_array (array : Array , * , xp : ModuleType ) -> np . typing . NDArray [ Any ]: # type: ignore[explicit-any]
69
94
"""
70
- Ensure that the array can be compared with xp.testing or np.testing.
71
-
72
- This involves transferring it from GPU to CPU memory, densifying it, etc.
95
+ Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
73
96
"""
74
- if is_torch_namespace (xp ):
75
- return array . cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
97
+ if is_cupy_namespace (xp ):
98
+ return xp . asnumpy ( array )
76
99
if is_pydata_sparse_namespace (xp ):
77
100
return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101
+
102
+ if is_torch_namespace (xp ):
103
+ array = to_device (array , "cpu" )
78
104
if is_array_api_strict_namespace (xp ):
79
- # Note: we deliberately did not add a `.to_device` method in _typing.pyi
80
- # even if it is required by the standard as many backends don't support it
81
- return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82
- # Note: nothing to do for CuPy, because it uses a bespoke test function
83
- return array
105
+ cpu : Device = xp .Device ("CPU_DEVICE" )
106
+ array = to_device (array , cpu )
107
+ if is_jax_namespace (xp ):
108
+ import jax
84
109
110
+ # Note: only needed if the transfer guard is enabled
111
+ cpu = cast (Device , jax .devices ("cpu" )[0 ])
112
+ array = to_device (array , cpu )
85
113
86
- def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
114
+ return np .asarray (array )
115
+
116
+
117
+ def xp_assert_equal (
118
+ actual : Array ,
119
+ desired : Array ,
120
+ * ,
121
+ err_msg : str = "" ,
122
+ check_dtype : bool = True ,
123
+ check_shape : bool = True ,
124
+ check_scalar : bool = False ,
125
+ ) -> None :
87
126
"""
88
127
Array-API compatible version of `np.testing.assert_array_equal`.
89
128
@@ -95,34 +134,56 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95
134
The expected array (typically hardcoded).
96
135
err_msg : str, optional
97
136
Error message to display on failure.
137
+ check_dtype, check_shape : bool, default: True
138
+ Whether to check agreement between actual and desired dtypes and shapes
139
+ check_scalar : bool, default: False
140
+ NumPy only: whether to check agreement between actual and desired types -
141
+ 0d array vs scalar.
98
142
99
143
See Also
100
144
--------
101
145
xp_assert_close : Similar function for inexact equality checks.
102
146
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103
147
"""
104
- xp = _check_ns_shape_dtype (actual , desired )
105
- actual = _prepare_for_test (actual , xp )
106
- desired = _prepare_for_test (desired , xp )
148
+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
149
+ actual_np = as_numpy_array (actual , xp = xp )
150
+ desired_np = as_numpy_array (desired , xp = xp )
151
+ np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
107
152
108
- if is_cupy_namespace (xp ):
109
- xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
110
- elif is_torch_namespace (xp ):
111
- # PyTorch recommends using `rtol=0, atol=0` like this
112
- # to test for exact equality
113
- xp .testing .assert_close (
114
- actual ,
115
- desired ,
116
- rtol = 0 ,
117
- atol = 0 ,
118
- equal_nan = True ,
119
- check_dtype = False ,
120
- msg = err_msg or None ,
121
- )
122
- else :
123
- import numpy as np # pylint: disable=import-outside-toplevel
124
153
125
- np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
154
+ def xp_assert_less (
155
+ x : Array ,
156
+ y : Array ,
157
+ * ,
158
+ err_msg : str = "" ,
159
+ check_dtype : bool = True ,
160
+ check_shape : bool = True ,
161
+ check_scalar : bool = False ,
162
+ ) -> None :
163
+ """
164
+ Array-API compatible version of `np.testing.assert_array_less`.
165
+
166
+ Parameters
167
+ ----------
168
+ x, y : Array
169
+ The arrays to compare according to ``x < y`` (elementwise).
170
+ err_msg : str, optional
171
+ Error message to display on failure.
172
+ check_dtype, check_shape : bool, default: True
173
+ Whether to check agreement between actual and desired dtypes and shapes
174
+ check_scalar : bool, default: False
175
+ NumPy only: whether to check agreement between actual and desired types -
176
+ 0d array vs scalar.
177
+
178
+ See Also
179
+ --------
180
+ xp_assert_close : Similar function for inexact equality checks.
181
+ numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182
+ """
183
+ xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
184
+ x_np = as_numpy_array (x , xp = xp )
185
+ y_np = as_numpy_array (y , xp = xp )
186
+ np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
126
187
127
188
128
189
def xp_assert_close (
@@ -132,6 +193,9 @@ def xp_assert_close(
132
193
rtol : float | None = None ,
133
194
atol : float = 0 ,
134
195
err_msg : str = "" ,
196
+ check_dtype : bool = True ,
197
+ check_shape : bool = True ,
198
+ check_scalar : bool = False ,
135
199
) -> None :
136
200
"""
137
201
Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +212,11 @@ def xp_assert_close(
148
212
Absolute tolerance. Default: 0.
149
213
err_msg : str, optional
150
214
Error message to display on failure.
215
+ check_dtype, check_shape : bool, default: True
216
+ Whether to check agreement between actual and desired dtypes and shapes
217
+ check_scalar : bool, default: False
218
+ NumPy only: whether to check agreement between actual and desired types -
219
+ 0d array vs scalar.
151
220
152
221
See Also
153
222
--------
@@ -159,40 +228,26 @@ def xp_assert_close(
159
228
-----
160
229
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161
230
"""
162
- xp = _check_ns_shape_dtype (actual , desired )
163
-
164
- floating = xp .isdtype (actual .dtype , ("real floating" , "complex floating" ))
165
- if rtol is None and floating :
166
- # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
167
- # roughly half way between sqrt(eps) and the default for
168
- # `numpy.testing.assert_allclose`, 1e-7
169
- rtol = xp .finfo (actual .dtype ).eps ** 0.5 * 4
170
- elif rtol is None :
171
- rtol = 1e-7
172
-
173
- actual = _prepare_for_test (actual , xp )
174
- desired = _prepare_for_test (desired , xp )
175
-
176
- if is_cupy_namespace (xp ):
177
- xp .testing .assert_allclose (
178
- actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
179
- )
180
- elif is_torch_namespace (xp ):
181
- xp .testing .assert_close (
182
- actual , desired , rtol = rtol , atol = atol , equal_nan = True , msg = err_msg or None
183
- )
184
- else :
185
- import numpy as np # pylint: disable=import-outside-toplevel
186
-
187
- # JAX/Dask arrays work directly with `np.testing`
188
- assert isinstance (rtol , float )
189
- np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190
- actual , # pyright: ignore[reportArgumentType]
191
- desired , # pyright: ignore[reportArgumentType]
192
- rtol = rtol ,
193
- atol = atol ,
194
- err_msg = err_msg ,
195
- )
231
+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
232
+
233
+ if rtol is None :
234
+ if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
235
+ # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
236
+ # roughly half way between sqrt(eps) and the default for
237
+ # `numpy.testing.assert_allclose`, 1e-7
238
+ rtol = xp .finfo (actual .dtype ).eps ** 0.5 * 4
239
+ else :
240
+ rtol = 1e-7
241
+
242
+ actual_np = as_numpy_array (actual , xp = xp )
243
+ desired_np = as_numpy_array (desired , xp = xp )
244
+ np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
245
+ actual_np ,
246
+ desired_np ,
247
+ rtol = rtol , # pyright: ignore[reportArgumentType]
248
+ atol = atol ,
249
+ err_msg = err_msg ,
250
+ )
196
251
197
252
198
253
def xfail (
0 commit comments