|
4 | 4 | from builtins import all as _builtin_all, any as _builtin_any
|
5 | 5 |
|
6 | 6 | from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
|
7 |
| - vecdot as _aliases_vecdot) |
| 7 | + vecdot as _aliases_vecdot, clip as _aliases_clip) |
8 | 8 | from .._internal import get_xp
|
9 | 9 |
|
10 | 10 | import torch
|
@@ -189,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
|
189 | 189 | return torch.clone(x)
|
190 | 190 | return torch.amin(x, axis, keepdims=keepdims)
|
191 | 191 |
|
| 192 | +clip = get_xp(torch)(_aliases_clip) |
| 193 | + |
192 | 194 | # torch.sort also returns a tuple
|
193 | 195 | # https://github.com/pytorch/pytorch/issues/70921
|
194 | 196 | def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
|
@@ -706,8 +708,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
|
706 | 708 | 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
|
707 | 709 | 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
|
708 | 710 | 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
|
709 |
| - 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', |
710 |
| - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', |
| 711 | + 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', |
| 712 | + 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', |
711 | 713 | 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
|
712 | 714 | 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
|
713 | 715 | 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
|
|
0 commit comments