Skip to content

Commit b8a59da

Browse files
committed
Wrap clip() for torch
1 parent 7da965c commit b8a59da

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

array_api_compat/common/_aliases.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,12 @@ def _isscalar(a):
315315
if min is not None:
316316
a = xp.broadcast_to(xp.asarray(min), result_shape)
317317
ia = (out < a) | xp.isnan(a)
318-
out[ia] = a[ia]
318+
# torch requires an explicit cast here
319+
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
319320
if max is not None:
320321
b = xp.broadcast_to(xp.asarray(max), result_shape)
321322
ib = (out > b) | xp.isnan(b)
322-
out[ib] = b[ib]
323+
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
323324
# Return a scalar for 0-D
324325
return out[()]
325326

array_api_compat/torch/_aliases.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7-
vecdot as _aliases_vecdot)
7+
vecdot as _aliases_vecdot, clip as _aliases_clip)
88
from .._internal import get_xp
99

1010
import torch
@@ -189,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
189189
return torch.clone(x)
190190
return torch.amin(x, axis, keepdims=keepdims)
191191

192+
clip = get_xp(torch)(_aliases_clip)
193+
192194
# torch.sort also returns a tuple
193195
# https://github.com/pytorch/pytorch/issues/70921
194196
def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
@@ -706,8 +708,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
706708
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
707709
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
708710
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
709-
'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
710-
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
711+
'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod',
712+
'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
711713
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
712714
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
713715
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',

0 commit comments

Comments
 (0)