Skip to content

Commit 93201f1

Browse files
committed
Fix passing of keyword arguments in __dlpack__ for NumPy 2.1
1 parent 67d9667 commit 93201f1

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

array_api_strict/_array_object.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,15 @@ def __dlpack__(
595595
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
596596

597597
return self._array.__dlpack__(stream=stream)
598-
return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy)
598+
else:
599+
kwargs = {'stream': stream}
600+
if max_version is not _default:
601+
kwargs['max_version'] = max_version
602+
if dl_device is not _default:
603+
kwargs['dl_device'] = dl_device
604+
if copy is not _default:
605+
kwargs['copy'] = copy
606+
return self._array.__dlpack__(**kwargs)
599607

600608
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
601609
"""

array_api_strict/tests/test_array_object.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -460,18 +460,27 @@ def dlpack_2023_12(api_version):
460460
a.__dlpack__()
461461

462462

463-
exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError
464-
pytest.raises(exception, lambda:
465-
a.__dlpack__(dl_device=CPU_DEVICE))
466-
pytest.raises(exception, lambda:
467-
a.__dlpack__(dl_device=None))
468-
pytest.raises(exception, lambda:
469-
a.__dlpack__(max_version=(1, 0)))
470-
pytest.raises(exception, lambda:
471-
a.__dlpack__(max_version=None))
472-
pytest.raises(exception, lambda:
473-
a.__dlpack__(copy=False))
474-
pytest.raises(exception, lambda:
475-
a.__dlpack__(copy=True))
476-
pytest.raises(exception, lambda:
477-
a.__dlpack__(copy=None))
463+
if np.__version__ < '2.1':
464+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
465+
pytest.raises(exception, lambda:
466+
a.__dlpack__(dl_device=CPU_DEVICE))
467+
pytest.raises(exception, lambda:
468+
a.__dlpack__(dl_device=None))
469+
pytest.raises(exception, lambda:
470+
a.__dlpack__(max_version=(1, 0)))
471+
pytest.raises(exception, lambda:
472+
a.__dlpack__(max_version=None))
473+
pytest.raises(exception, lambda:
474+
a.__dlpack__(copy=False))
475+
pytest.raises(exception, lambda:
476+
a.__dlpack__(copy=True))
477+
pytest.raises(exception, lambda:
478+
a.__dlpack__(copy=None))
479+
else:
480+
a.__dlpack__(dl_device=CPU_DEVICE)
481+
a.__dlpack__(dl_device=None)
482+
a.__dlpack__(max_version=(1, 0))
483+
a.__dlpack__(max_version=None)
484+
a.__dlpack__(copy=False)
485+
a.__dlpack__(copy=True)
486+
a.__dlpack__(copy=None)

0 commit comments

Comments
 (0)