Skip to content

Commit ad14515

Browse files
authored
Merge pull request #157 from asmeurer/2023-support
2023.12 support
2 parents b15a815 + b8a59da commit ad14515

File tree

6 files changed

+77
-8
lines changed

6 files changed

+77
-8
lines changed

array_api_compat/common/_aliases.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import _check_device
15+
from ._helpers import array_namespace, _check_device
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -264,6 +264,66 @@ def var(
264264
) -> ndarray:
265265
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266266

267+
268+
# The min and max argument names in clip are different and not optional in numpy, and type
269+
# promotion behavior is different.
270+
def clip(
271+
x: ndarray,
272+
/,
273+
min: Optional[Union[int, float, ndarray]] = None,
274+
max: Optional[Union[int, float, ndarray]] = None,
275+
*,
276+
xp,
277+
# TODO: np.clip has other ufunc kwargs
278+
out: Optional[ndarray] = None,
279+
) -> ndarray:
280+
def _isscalar(a):
281+
return isinstance(a, (int, float, type(None)))
282+
min_shape = () if _isscalar(min) else min.shape
283+
max_shape = () if _isscalar(max) else max.shape
284+
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
285+
286+
wrapped_xp = array_namespace(x)
287+
288+
# np.clip does type promotion but the array API clip requires that the
289+
# output have the same dtype as x. We do this instead of just downcasting
290+
# the result of xp.clip() to handle some corner cases better (e.g.,
291+
# avoiding uint64 -> float64 promotion).
292+
293+
# Note: cases where min or max overflow (integer) or round (float) in the
294+
# wrong direction when downcasting to x.dtype are unspecified. This code
295+
# just does whatever NumPy does when it downcasts in the assignment, but
296+
# other behavior could be preferred, especially for integers. For example,
297+
# this code produces:
298+
299+
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
300+
# -128
301+
302+
# but an answer of 0 might be preferred. See
303+
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
304+
305+
306+
# At least handle the case of Python integers correctly (see
307+
# https://github.com/numpy/numpy/pull/26892).
308+
if type(min) is int and min <= xp.iinfo(x.dtype).min:
309+
min = None
310+
if type(max) is int and max >= xp.iinfo(x.dtype).max:
311+
max = None
312+
313+
if out is None:
314+
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
315+
if min is not None:
316+
a = xp.broadcast_to(xp.asarray(min), result_shape)
317+
ia = (out < a) | xp.isnan(a)
318+
# torch requires an explicit cast here
319+
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
320+
if max is not None:
321+
b = xp.broadcast_to(xp.asarray(max), result_shape)
322+
ib = (out > b) | xp.isnan(b)
323+
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
324+
# Return a scalar for 0-D
325+
return out[()]
326+
267327
# Unlike transpose(), the axes argument to permute_dims() is required.
268328
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269329
return xp.transpose(x, axes)
@@ -465,6 +525,6 @@ def isdtype(
465525
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
466526
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
467527
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
468-
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
528+
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
469529
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
470530
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(cp)(_aliases.std)
4949
var = get_xp(cp)(_aliases.var)
50+
clip = get_xp(cp)(_aliases.clip)
5051
permute_dims = get_xp(cp)(_aliases.permute_dims)
5152
reshape = get_xp(cp)(_aliases.reshape)
5253
argsort = get_xp(cp)(_aliases.argsort)

array_api_compat/dask/array/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def _dask_arange(
8888
permute_dims = get_xp(da)(_aliases.permute_dims)
8989
std = get_xp(da)(_aliases.std)
9090
var = get_xp(da)(_aliases.var)
91+
clip = get_xp(da)(_aliases.clip)
9192
empty = get_xp(da)(_aliases.empty)
9293
empty_like = get_xp(da)(_aliases.empty_like)
9394
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
astype = _aliases.astype
4848
std = get_xp(np)(_aliases.std)
4949
var = get_xp(np)(_aliases.var)
50+
clip = get_xp(np)(_aliases.clip)
5051
permute_dims = get_xp(np)(_aliases.permute_dims)
5152
reshape = get_xp(np)(_aliases.reshape)
5253
argsort = get_xp(np)(_aliases.argsort)

array_api_compat/torch/_aliases.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7-
vecdot as _aliases_vecdot)
7+
vecdot as _aliases_vecdot, clip as _aliases_clip)
88
from .._internal import get_xp
99

1010
import torch
@@ -155,6 +155,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
155155
bitwise_or = _two_arg(torch.bitwise_or)
156156
bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
157157
bitwise_xor = _two_arg(torch.bitwise_xor)
158+
copysign = _two_arg(torch.copysign)
158159
divide = _two_arg(torch.divide)
159160
# Also a rename. torch.equal does not broadcast
160161
equal = _two_arg(torch.eq)
@@ -188,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
188189
return torch.clone(x)
189190
return torch.amin(x, axis, keepdims=keepdims)
190191

192+
clip = get_xp(torch)(_aliases_clip)
193+
191194
# torch.sort also returns a tuple
192195
# https://github.com/pytorch/pytorch/issues/70921
193196
def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
@@ -702,11 +705,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
702705

703706
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
704707
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
705-
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
706-
'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
707-
'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
708-
'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
709-
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
708+
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
709+
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
710+
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
711+
'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod',
712+
'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
710713
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
711714
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
712715
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',

dask-xfails.txt

+3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
4848
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
4949
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
5050

51+
# The clip helper uses boolean indexing
52+
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
53+
5154
# No sorting in dask
5255
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
5356
array_api_tests/test_has_names.py::test_has_names[sorting-sort]

0 commit comments

Comments
 (0)