We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 621494b commit fdf9489Copy full SHA for fdf9489
array_api_compat/numpy/_aliases.py
@@ -134,6 +134,13 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
134
return result
135
136
137
+# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default
138
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
139
+ if axis is None:
140
+ axis = -1
141
+ return np.take_along_axis(x, indices, axis=axis)
142
+
143
144
# These functions are completely new here. If the library already has them
145
# (i.e., numpy 2.0), use the library version instead of our wrapper.
146
if hasattr(np, 'vecdot'):
@@ -155,6 +162,7 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
155
162
'acos', 'acosh', 'asin', 'asinh', 'atan',
156
163
'atan2', 'atanh', 'bitwise_left_shift',
157
164
'bitwise_invert', 'bitwise_right_shift',
158
- 'bool', 'concat', 'count_nonzero', 'pow']
165
+ 'bool', 'concat', 'count_nonzero', 'pow',
166
+ 'take_along_axis']
159
167
160
168
_all_ignore = ['np', 'get_xp']
0 commit comments