From bbf346c1b166e77f7da39fbcbc3e348b9e2598c8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 14:59:13 +0200 Subject: [PATCH 1/6] BUG: take_along_axis: numpy requires an axis --- array_api_compat/numpy/_aliases.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..0b2bf3e5 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -140,6 +140,13 @@ def count_nonzero( return result +# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + if axis is None: + axis = -1 + return np.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -157,6 +164,7 @@ def count_nonzero( else: unstack = get_xp(np)(_aliases.unstack) + __all__ = [ "__array_namespace_info__", "asarray", @@ -175,6 +183,7 @@ def count_nonzero( "concat", "count_nonzero", "pow", + "take_along_axis" ] __all__ += _aliases.__all__ _all_ignore = ["np", "get_xp"] @@ -182,3 +191,4 @@ def count_nonzero( def __dir__() -> list[str]: return __all__ + From d8fa04afa589c70edb96ab47ce5ee9e94554a49a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 19:45:18 +0200 Subject: [PATCH 2/6] MAINT: move take_along_axis to common/_aliases --- array_api_compat/common/_aliases.py | 7 +++++++ array_api_compat/cupy/_aliases.py | 1 + array_api_compat/numpy/_aliases.py | 8 +------- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..3f24e65c 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -517,6 +517,12 @@ def sort( return res +# take_along_axis: axis defaults to -1; numpy, cupy do not have a default value; +# pytorch defaults to None, which ravels. +def take_along_axis(x: Array, indices: Array, /, *, xp: Namespace, axis: int = -1): + return xp.take_along_axis(x, indices, axis=axis) + + # nonzero should error for zero-dimensional arrays def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: @@ -713,6 +719,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "matmul", "matrix_transpose", "tensordot", + "take_along_axis", "vecdot", "isdtype", "unstack", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..3467b52b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -63,6 +63,7 @@ sign = get_xp(cp)(_aliases.sign) finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) +take_along_axis = get_xp(cp)(_aliases.take_along_axis) _copy_default = object() diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 0b2bf3e5..7071d3f9 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -72,6 +72,7 @@ sign = get_xp(np)(_aliases.sign) finfo = get_xp(np)(_aliases.finfo) iinfo = get_xp(np)(_aliases.iinfo) +take_along_axis = get_xp(np)(_aliases.take_along_axis) def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] @@ -140,13 +141,6 @@ def count_nonzero( return result -# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): - if axis is None: - axis = -1 - return np.take_along_axis(x, indices, axis=axis) - - # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): From 0aaefd6239a4db667fcfdae4b01164cf16e157fc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 19:57:39 +0200 Subject: [PATCH 3/6] BUG: cannot add take_along_axis to common/_aliases because of dask --- array_api_compat/common/_aliases.py | 7 ------- array_api_compat/cupy/_aliases.py | 9 +++++++-- array_api_compat/numpy/_aliases.py | 8 +++++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 3f24e65c..8ea9162a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -517,12 +517,6 @@ def sort( return res -# take_along_axis: axis defaults to -1; numpy, cupy do not have a default value; -# pytorch defaults to None, which ravels. -def take_along_axis(x: Array, indices: Array, /, *, xp: Namespace, axis: int = -1): - return xp.take_along_axis(x, indices, axis=axis) - - # nonzero should error for zero-dimensional arrays def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: @@ -719,7 +713,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "matmul", "matrix_transpose", "tensordot", - "take_along_axis", "vecdot", "isdtype", "unstack", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 3467b52b..37e2b8b8 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -63,7 +63,6 @@ sign = get_xp(cp)(_aliases.sign) finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) -take_along_axis = get_xp(cp)(_aliases.take_along_axis) _copy_default = object() @@ -139,6 +138,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return xp.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -160,6 +164,7 @@ def count_nonzero( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 7071d3f9..97c75a6d 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -72,7 +72,6 @@ sign = get_xp(np)(_aliases.sign) finfo = get_xp(np)(_aliases.finfo) iinfo = get_xp(np)(_aliases.iinfo) -take_along_axis = get_xp(np)(_aliases.take_along_axis) def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] @@ -141,6 +140,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in numpy axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return xp.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -158,7 +162,6 @@ def count_nonzero( else: unstack = get_xp(np)(_aliases.unstack) - __all__ = [ "__array_namespace_info__", "asarray", @@ -185,4 +188,3 @@ def count_nonzero( def __dir__() -> list[str]: return __all__ - From 411ee89f4ded7d0c7204f05d210a8cb5bf554c8f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 20:00:52 +0200 Subject: [PATCH 4/6] . --- array_api_compat/cupy/_aliases.py | 2 +- array_api_compat/numpy/_aliases.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 37e2b8b8..b70f0fc1 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -140,7 +140,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): - return xp.take_along_axis(x, indices, axis=axis) + return np.take_along_axis(x, indices, axis=axis) # These functions are completely new here. If the library already has them diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 97c75a6d..a1aee5c0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -142,7 +142,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in numpy axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): - return xp.take_along_axis(x, indices, axis=axis) + return np.take_along_axis(x, indices, axis=axis) # These functions are completely new here. If the library already has them From e0359ffb3a057162bd0316d160e1479c423cb82e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 20:03:02 +0200 Subject: [PATCH 5/6] .. --- array_api_compat/cupy/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index b70f0fc1..48f68f01 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -140,7 +140,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): - return np.take_along_axis(x, indices, axis=axis) + return cp.take_along_axis(x, indices, axis=axis) # These functions are completely new here. If the library already has them From 7596ce62f4e7d3f6b62a7d514fa9ce1ebc109c83 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 20:29:23 +0200 Subject: [PATCH 6/6] add a skip for dask.array --- dask-xfails.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index 932aeada..3efb4f96 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take +# missing `take_along_axis`, https://github.com/dask/dask/issues/3663 +array_api_tests/test_indexing_functions.py::test_take_along_axis + # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device]