diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py index cff25f99..c710c41d 100644 --- a/array_api_tests/test_inspection_functions.py +++ b/array_api_tests/test_inspection_functions.py @@ -1,28 +1,51 @@ import pytest from hypothesis import given, strategies as st -from array_api_tests.dtype_helpers import available_kinds +from array_api_tests.dtype_helpers import available_kinds, dtype_names from . import xp pytestmark = pytest.mark.min_version("2023.12") -def test_array_namespace_info(): - out = xp.__array_namespace_info__() +class TestInspection: + def test_capabilities(self): + out = xp.__array_namespace_info__() - capabilities = out.capabilities() - assert isinstance(capabilities, dict) + capabilities = out.capabilities() + assert isinstance(capabilities, dict) - out.default_device() + expected_attr = {"boolean indexing": bool, "data-dependent shapes": bool} + if xp.__array_api_version__ >= "2024.12": + expected_attr.update(**{"max dimensions": int}) - default_dtypes = out.default_dtypes() - assert isinstance(default_dtypes, dict) - expected_subset = {"real floating", "complex floating", "integral"} & available_kinds() | {"indexing"} - assert expected_subset.issubset(set(default_dtypes.keys())) + for attr, typ in expected_attr.items(): + assert attr in capabilities, f'capabilites is missing "{attr}".' + assert isinstance(capabilities[attr], typ) + + assert capabilities.get("max dimensions", 100500) > 0 + + def test_devices(self): + out = xp.__array_namespace_info__() + + assert hasattr(out, "devices") + assert hasattr(out, "default_device") + + assert isinstance(out.devices(), list) + assert out.default_device() in out.devices() + + def test_default_dtypes(self): + out = xp.__array_namespace_info__() + + for device in xp.__array_namespace_info__().devices(): + default_dtypes = out.default_dtypes(device=device) + assert isinstance(default_dtypes, dict) + expected_subset = ( + {"real floating", "complex floating", "integral"} + & available_kinds() + | {"indexing"} + ) + assert expected_subset.issubset(set(default_dtypes.keys())) - devices = out.devices() - assert isinstance(devices, list) - atomic_kinds = [ "bool", @@ -34,12 +57,21 @@ def test_array_namespace_info(): @given( - st.one_of( + kind=st.one_of( st.none(), st.sampled_from(atomic_kinds + ["integral", "numeric"]), st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple), + ), + device=st.one_of( + st.none(), + st.sampled_from(xp.__array_namespace_info__().devices()) ) ) -def test_array_namespace_info_dtypes(kind): - out = xp.__array_namespace_info__().dtypes(kind=kind) +def test_array_namespace_info_dtypes(kind, device): + out = xp.__array_namespace_info__().dtypes(kind=kind, device=device) assert isinstance(out, dict) + + for name, dtyp in out.items(): + assert name in dtype_names + xp.empty(1, dtype=dtyp, device=device) # check `dtyp` is a valid dtype +