|
12 | 12 | from typing import NamedTuple
|
13 | 13 | import inspect
|
14 | 14 |
|
15 |
| -from ._helpers import _check_device |
| 15 | +from ._helpers import array_namespace, _check_device |
16 | 16 |
|
17 | 17 | # These functions are modified from the NumPy versions.
|
18 | 18 |
|
@@ -264,6 +264,66 @@ def var(
|
264 | 264 | ) -> ndarray:
|
265 | 265 | return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
266 | 266 |
|
| 267 | + |
| 268 | +# The min and max argument names in clip are different and not optional in numpy, and type |
| 269 | +# promotion behavior is different. |
| 270 | +def clip( |
| 271 | + x: ndarray, |
| 272 | + /, |
| 273 | + min: Optional[Union[int, float, ndarray]] = None, |
| 274 | + max: Optional[Union[int, float, ndarray]] = None, |
| 275 | + *, |
| 276 | + xp, |
| 277 | + # TODO: np.clip has other ufunc kwargs |
| 278 | + out: Optional[ndarray] = None, |
| 279 | +) -> ndarray: |
| 280 | + def _isscalar(a): |
| 281 | + return isinstance(a, (int, float, type(None))) |
| 282 | + min_shape = () if _isscalar(min) else min.shape |
| 283 | + max_shape = () if _isscalar(max) else max.shape |
| 284 | + result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape) |
| 285 | + |
| 286 | + wrapped_xp = array_namespace(x) |
| 287 | + |
| 288 | + # np.clip does type promotion but the array API clip requires that the |
| 289 | + # output have the same dtype as x. We do this instead of just downcasting |
| 290 | + # the result of xp.clip() to handle some corner cases better (e.g., |
| 291 | + # avoiding uint64 -> float64 promotion). |
| 292 | + |
| 293 | + # Note: cases where min or max overflow (integer) or round (float) in the |
| 294 | + # wrong direction when downcasting to x.dtype are unspecified. This code |
| 295 | + # just does whatever NumPy does when it downcasts in the assignment, but |
| 296 | + # other behavior could be preferred, especially for integers. For example, |
| 297 | + # this code produces: |
| 298 | + |
| 299 | + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) |
| 300 | + # -128 |
| 301 | + |
| 302 | + # but an answer of 0 might be preferred. See |
| 303 | + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. |
| 304 | + |
| 305 | + |
| 306 | + # At least handle the case of Python integers correctly (see |
| 307 | + # https://github.com/numpy/numpy/pull/26892). |
| 308 | + if type(min) is int and min <= xp.iinfo(x.dtype).min: |
| 309 | + min = None |
| 310 | + if type(max) is int and max >= xp.iinfo(x.dtype).max: |
| 311 | + max = None |
| 312 | + |
| 313 | + if out is None: |
| 314 | + out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True) |
| 315 | + if min is not None: |
| 316 | + a = xp.broadcast_to(xp.asarray(min), result_shape) |
| 317 | + ia = (out < a) | xp.isnan(a) |
| 318 | + # torch requires an explicit cast here |
| 319 | + out[ia] = wrapped_xp.astype(a[ia], out.dtype) |
| 320 | + if max is not None: |
| 321 | + b = xp.broadcast_to(xp.asarray(max), result_shape) |
| 322 | + ib = (out > b) | xp.isnan(b) |
| 323 | + out[ib] = wrapped_xp.astype(b[ib], out.dtype) |
| 324 | + # Return a scalar for 0-D |
| 325 | + return out[()] |
| 326 | + |
267 | 327 | # Unlike transpose(), the axes argument to permute_dims() is required.
|
268 | 328 | def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
|
269 | 329 | return xp.transpose(x, axes)
|
@@ -465,6 +525,6 @@ def isdtype(
|
465 | 525 | 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
|
466 | 526 | 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
467 | 527 | 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
468 |
| - 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', |
| 528 | + 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', |
469 | 529 | 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
|
470 | 530 | 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
|
0 commit comments