Skip to content

Commit 82db3bf

Browse files
mdhabercrusaderkylucascolley
authored
ENH/TST: xp_assert_ enhancements (#267)
* WIP: xp_assert enhancements * ENH: add xp_assert_less * Rework prepare_for_test (#2) * Fix failures in #267 * Update tests/test_testing.py [skip ci] Co-authored-by: Lucas Colley <[email protected]> * Update _testing.py --------- Co-authored-by: Guido Imperiale <[email protected]> Co-authored-by: Lucas Colley <[email protected]>
1 parent c9204ea commit 82db3bf

File tree

3 files changed

+224
-92
lines changed

3 files changed

+224
-92
lines changed

src/array_api_extra/_lib/_testing.py

+139-84
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,37 @@
55
See also ..testing for public testing utilities.
66
"""
77

8+
from __future__ import annotations
9+
810
import math
911
from types import ModuleType
10-
from typing import cast
12+
from typing import Any, cast
1113

14+
import numpy as np
1215
import pytest
1316

1417
from ._utils._compat import (
1518
array_namespace,
1619
is_array_api_strict_namespace,
1720
is_cupy_namespace,
1821
is_dask_namespace,
22+
is_jax_namespace,
23+
is_numpy_namespace,
1924
is_pydata_sparse_namespace,
2025
is_torch_namespace,
26+
to_device,
2127
)
22-
from ._utils._typing import Array
28+
from ._utils._typing import Array, Device
2329

24-
__all__ = ["xp_assert_close", "xp_assert_equal"]
30+
__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]
2531

2632

2733
def _check_ns_shape_dtype(
28-
actual: Array, desired: Array
34+
actual: Array,
35+
desired: Array,
36+
check_dtype: bool,
37+
check_shape: bool,
38+
check_scalar: bool,
2939
) -> ModuleType: # numpydoc ignore=RT03
3040
"""
3141
Assert that namespace, shape and dtype of the two arrays match.
@@ -36,6 +46,11 @@ def _check_ns_shape_dtype(
3646
The array produced by the tested function.
3747
desired : Array
3848
The expected array (typically hardcoded).
49+
check_dtype, check_shape : bool, default: True
50+
Whether to check agreement between actual and desired dtypes and shapes
51+
check_scalar : bool, default: False
52+
NumPy only: whether to check agreement between actual and desired types -
53+
0d array vs scalar.
3954
4055
Returns
4156
-------
@@ -47,43 +62,67 @@ def _check_ns_shape_dtype(
4762
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
4863
assert actual_xp == desired_xp, msg
4964

50-
actual_shape = actual.shape
51-
desired_shape = desired.shape
52-
if is_dask_namespace(desired_xp):
53-
# Dask uses nan instead of None for unknown shapes
54-
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
55-
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56-
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
57-
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58-
59-
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
60-
assert actual_shape == desired_shape, msg
61-
62-
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
63-
assert actual.dtype == desired.dtype, msg
65+
if check_shape:
66+
actual_shape = actual.shape
67+
desired_shape = desired.shape
68+
if is_dask_namespace(desired_xp):
69+
# Dask uses nan instead of None for unknown shapes
70+
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
71+
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72+
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
73+
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+
75+
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
76+
assert actual_shape == desired_shape, msg
77+
78+
if check_dtype:
79+
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
80+
assert actual.dtype == desired.dtype, msg
81+
82+
if is_numpy_namespace(actual_xp) and check_scalar:
83+
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
84+
_msg = (
85+
"array-ness does not match:\n Actual: "
86+
f"{type(actual)}\n Desired: {type(desired)}"
87+
)
88+
assert np.isscalar(actual) == np.isscalar(desired), _msg
6489

6590
return desired_xp
6691

6792

68-
def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
93+
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
6994
"""
70-
Ensure that the array can be compared with xp.testing or np.testing.
71-
72-
This involves transferring it from GPU to CPU memory, densifying it, etc.
95+
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
7396
"""
74-
if is_torch_namespace(xp):
75-
return array.cpu() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
97+
if is_cupy_namespace(xp):
98+
return xp.asnumpy(array)
7699
if is_pydata_sparse_namespace(xp):
77100
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101+
102+
if is_torch_namespace(xp):
103+
array = to_device(array, "cpu")
78104
if is_array_api_strict_namespace(xp):
79-
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
80-
# even if it is required by the standard as many backends don't support it
81-
return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82-
# Note: nothing to do for CuPy, because it uses a bespoke test function
83-
return array
105+
cpu: Device = xp.Device("CPU_DEVICE")
106+
array = to_device(array, cpu)
107+
if is_jax_namespace(xp):
108+
import jax
84109

110+
# Note: only needed if the transfer guard is enabled
111+
cpu = cast(Device, jax.devices("cpu")[0])
112+
array = to_device(array, cpu)
85113

86-
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
114+
return np.asarray(array)
115+
116+
117+
def xp_assert_equal(
118+
actual: Array,
119+
desired: Array,
120+
*,
121+
err_msg: str = "",
122+
check_dtype: bool = True,
123+
check_shape: bool = True,
124+
check_scalar: bool = False,
125+
) -> None:
87126
"""
88127
Array-API compatible version of `np.testing.assert_array_equal`.
89128
@@ -95,34 +134,56 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95134
The expected array (typically hardcoded).
96135
err_msg : str, optional
97136
Error message to display on failure.
137+
check_dtype, check_shape : bool, default: True
138+
Whether to check agreement between actual and desired dtypes and shapes
139+
check_scalar : bool, default: False
140+
NumPy only: whether to check agreement between actual and desired types -
141+
0d array vs scalar.
98142
99143
See Also
100144
--------
101145
xp_assert_close : Similar function for inexact equality checks.
102146
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103147
"""
104-
xp = _check_ns_shape_dtype(actual, desired)
105-
actual = _prepare_for_test(actual, xp)
106-
desired = _prepare_for_test(desired, xp)
148+
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
149+
actual_np = as_numpy_array(actual, xp=xp)
150+
desired_np = as_numpy_array(desired, xp=xp)
151+
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
107152

108-
if is_cupy_namespace(xp):
109-
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
110-
elif is_torch_namespace(xp):
111-
# PyTorch recommends using `rtol=0, atol=0` like this
112-
# to test for exact equality
113-
xp.testing.assert_close(
114-
actual,
115-
desired,
116-
rtol=0,
117-
atol=0,
118-
equal_nan=True,
119-
check_dtype=False,
120-
msg=err_msg or None,
121-
)
122-
else:
123-
import numpy as np # pylint: disable=import-outside-toplevel
124153

125-
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
154+
def xp_assert_less(
155+
x: Array,
156+
y: Array,
157+
*,
158+
err_msg: str = "",
159+
check_dtype: bool = True,
160+
check_shape: bool = True,
161+
check_scalar: bool = False,
162+
) -> None:
163+
"""
164+
Array-API compatible version of `np.testing.assert_array_less`.
165+
166+
Parameters
167+
----------
168+
x, y : Array
169+
The arrays to compare according to ``x < y`` (elementwise).
170+
err_msg : str, optional
171+
Error message to display on failure.
172+
check_dtype, check_shape : bool, default: True
173+
Whether to check agreement between actual and desired dtypes and shapes
174+
check_scalar : bool, default: False
175+
NumPy only: whether to check agreement between actual and desired types -
176+
0d array vs scalar.
177+
178+
See Also
179+
--------
180+
xp_assert_close : Similar function for inexact equality checks.
181+
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182+
"""
183+
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
184+
x_np = as_numpy_array(x, xp=xp)
185+
y_np = as_numpy_array(y, xp=xp)
186+
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
126187

127188

128189
def xp_assert_close(
@@ -132,6 +193,9 @@ def xp_assert_close(
132193
rtol: float | None = None,
133194
atol: float = 0,
134195
err_msg: str = "",
196+
check_dtype: bool = True,
197+
check_shape: bool = True,
198+
check_scalar: bool = False,
135199
) -> None:
136200
"""
137201
Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +212,11 @@ def xp_assert_close(
148212
Absolute tolerance. Default: 0.
149213
err_msg : str, optional
150214
Error message to display on failure.
215+
check_dtype, check_shape : bool, default: True
216+
Whether to check agreement between actual and desired dtypes and shapes
217+
check_scalar : bool, default: False
218+
NumPy only: whether to check agreement between actual and desired types -
219+
0d array vs scalar.
151220
152221
See Also
153222
--------
@@ -159,40 +228,26 @@ def xp_assert_close(
159228
-----
160229
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161230
"""
162-
xp = _check_ns_shape_dtype(actual, desired)
163-
164-
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
165-
if rtol is None and floating:
166-
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
167-
# roughly half way between sqrt(eps) and the default for
168-
# `numpy.testing.assert_allclose`, 1e-7
169-
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
170-
elif rtol is None:
171-
rtol = 1e-7
172-
173-
actual = _prepare_for_test(actual, xp)
174-
desired = _prepare_for_test(desired, xp)
175-
176-
if is_cupy_namespace(xp):
177-
xp.testing.assert_allclose(
178-
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
179-
)
180-
elif is_torch_namespace(xp):
181-
xp.testing.assert_close(
182-
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
183-
)
184-
else:
185-
import numpy as np # pylint: disable=import-outside-toplevel
186-
187-
# JAX/Dask arrays work directly with `np.testing`
188-
assert isinstance(rtol, float)
189-
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190-
actual, # pyright: ignore[reportArgumentType]
191-
desired, # pyright: ignore[reportArgumentType]
192-
rtol=rtol,
193-
atol=atol,
194-
err_msg=err_msg,
195-
)
231+
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
232+
233+
if rtol is None:
234+
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
235+
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
236+
# roughly half way between sqrt(eps) and the default for
237+
# `numpy.testing.assert_allclose`, 1e-7
238+
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
239+
else:
240+
rtol = 1e-7
241+
242+
actual_np = as_numpy_array(actual, xp=xp)
243+
desired_np = as_numpy_array(desired, xp=xp)
244+
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
245+
actual_np,
246+
desired_np,
247+
rtol=rtol, # pyright: ignore[reportArgumentType]
248+
atol=atol,
249+
err_msg=err_msg,
250+
)
196251

197252

198253
def xfail(

tests/test_funcs.py

-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ def test_device(self, xp: ModuleType, device: Device):
196196
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
197197
assert get_device(y) == device
198198

199-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
200199
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
201200
@hypothesis.settings(
202201
# The xp and library fixtures are not regenerated between hypothesis iterations

0 commit comments

Comments
 (0)