diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..48f68f01 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -138,6 +138,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return cp.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -159,6 +164,7 @@ def count_nonzero( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..a1aee5c0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -140,6 +140,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in numpy axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return np.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -175,6 +180,7 @@ def count_nonzero( "concat", "count_nonzero", "pow", + "take_along_axis" ] __all__ += _aliases.__all__ _all_ignore = ["np", "get_xp"] diff --git a/dask-xfails.txt b/dask-xfails.txt index 932aeada..3efb4f96 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take +# missing `take_along_axis`, https://github.com/dask/dask/issues/3663 +array_api_tests/test_indexing_functions.py::test_take_along_axis + # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device]