Skip to content

Commit 548b321

Browse files
mlloreda9prady9
authored andcommitted
Added topk() support from v3.6.0.
1 parent 24ec3fb commit 548b321

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

arrayfire/library.py

+8
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,14 @@ class DIFFUSION(_Enum):
438438
GRAD = _Enum_Type(1)
439439
MCDE = _Enum_Type(2)
440440

441+
class TOPK(_Enum):
442+
"""
443+
Top-K ordering
444+
"""
445+
DEFAULT = _Enum_Type(0)
446+
MIN = _Enum_Type(1)
447+
MAX = _Enum_Type(2)
448+
441449
_VER_MAJOR_PLACEHOLDER = "__VER_MAJOR__"
442450

443451
def _setup():

arrayfire/statistics.py

+37
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,40 @@ def corrcoef(x, y):
108108
real = real.value
109109
imag = imag.value
110110
return real if imag == 0 else real + imag * 1j
111+
112+
def topk(data, k, dim=0, order=TOPK.DEFAULT):
113+
"""
114+
Return top k elements along a single dimension.
115+
116+
Parameters
117+
----------
118+
119+
data: af.Array
120+
Input array to return k elements from.
121+
122+
k: scalar. default: 0
123+
The number of elements to return from input array.
124+
125+
dim: optional: scalar. default: 0
126+
The dimension along which the top k elements are
127+
extracted. Note: at the moment, topk() only supports the
128+
extraction of values along the first dimension.
129+
130+
order: optional: af.TOPK. default: af.TOPK.DEFAULT
131+
The ordering of k extracted elements. Defaults to top k max values.
132+
133+
Returns
134+
-------
135+
136+
values: af.Array
137+
Top k elements from input array.
138+
indices: af.Array
139+
Corresponding index array to top k elements.
140+
"""
141+
142+
values = Array()
143+
indices = Array()
144+
145+
safe_call(backend.get().af_topk(c_pointer(values.arr), c_pointer(indices.arr), data.arr, k, c_int_t(dim), order.value))
146+
147+
return values,indices

arrayfire/tests/simple/statistics.py

+9
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,13 @@ def simple_statistics(verbose=False):
4444

4545
print_func(af.corrcoef(a, b))
4646

47+
data = af.iota(5, 3)
48+
k = 3
49+
dim = 0
50+
order = af.TOPK.DEFAULT # defaults to af.TOPK.MAX
51+
assert(dim == 0) # topk currently supports first dim only
52+
values,indices = af.topk(data, k, dim, order)
53+
display_func(values)
54+
display_func(indices)
55+
4756
_util.tests['statistics'] = simple_statistics

0 commit comments

Comments
 (0)