Skip to content

Commit 3f97d53

Browse files
committed
test(types): Test SpatialImage API type inference
1 parent d310807 commit 3f97d53

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ test = [
7676
"pytest-httpserver >=1.0.7",
7777
"pytest-xdist >=3.5",
7878
"coverage[toml]>=7.2",
79+
"pytest-mypy-testing>=0.1.3",
7980
]
8081
# Remaining: Simpler to centralize in tox
8182
dev = ["tox"]

tests/ruff.toml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
line-length = 200

tests/typing/test_spatialimage_api.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import typing as ty
2+
3+
import numpy as np
4+
import pytest
5+
6+
from nibabel import AnalyzeImage, Spm99AnalyzeImage, Spm2AnalyzeImage, Nifti1Image, Nifti2Image, MGHImage
7+
from nibabel.spatialimages import SpatialImage
8+
9+
if ty.TYPE_CHECKING:
10+
from typing import reveal_type
11+
else:
12+
13+
def reveal_type(x: ty.Any) -> None:
14+
pass
15+
16+
17+
@pytest.mark.mypy_testing
18+
def test_affine_tracking() -> None:
19+
img_with_affine = SpatialImage(np.empty((5, 5, 5)), np.eye(4))
20+
img_without_affine = SpatialImage(np.empty((5, 5, 5)), None)
21+
22+
reveal_type(img_with_affine) # R: nibabel.spatialimages.SpatialImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
23+
reveal_type(img_without_affine) # R: nibabel.spatialimages.SpatialImage[None]
24+
25+
26+
@pytest.mark.mypy_testing
27+
def test_SpatialImageAPI() -> None:
28+
img = SpatialImage(np.empty((5, 5, 5)), np.eye(4))
29+
30+
# Affine
31+
reveal_type(img.affine) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
32+
reveal_type(SpatialImage(np.empty((5, 5, 5)), None).affine) # R: None
33+
34+
# Data
35+
reveal_type(img.dataobj) # R: nibabel.arrayproxy.ArrayLike
36+
reveal_type(img.get_fdata()) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
37+
reveal_type(img.get_fdata(dtype=np.float32)) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]]
38+
reveal_type(img.get_fdata(dtype=np.float64)) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
39+
reveal_type(img.get_fdata(dtype=np.dtype(np.float32))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]]
40+
reveal_type(img.get_fdata(dtype=np.dtype(np.float64))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
41+
reveal_type(img.get_fdata(dtype=np.dtype("f4"))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]]
42+
reveal_type(img.get_fdata(dtype=np.dtype("f8"))) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
43+
reveal_type(img.get_fdata(dtype="f4")) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.floating[numpy._typing._nbit_base._32Bit]]]
44+
reveal_type(img.get_fdata(dtype="f8")) # R: numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]
45+
46+
# Indirect header
47+
reveal_type(img.shape) # R: builtins.tuple[builtins.int, ...]
48+
reveal_type(img.ndim) # R: builtins.int
49+
50+
# SpatialHeader fields
51+
reveal_type(img.header.get_data_dtype()) # R: numpy.dtype[Any]
52+
reveal_type(img.header.get_data_shape()) # R: builtins.tuple[builtins.int, ...]
53+
reveal_type(img.header.get_zooms()) # R: builtins.tuple[builtins.float, ...]
54+
55+
56+
@pytest.mark.mypy_testing
57+
def test_image_and_header_types() -> None:
58+
analyze_img = AnalyzeImage(np.empty((5, 5, 5)), np.eye(4))
59+
reveal_type(analyze_img) # R: nibabel.analyze.AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
60+
reveal_type(analyze_img.header) # R: nibabel.analyze.AnalyzeHeader
61+
62+
spm99_img = Spm99AnalyzeImage(np.empty((5, 5, 5)), np.eye(4))
63+
reveal_type(spm99_img) # R: nibabel.spm99analyze.Spm99AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
64+
reveal_type(spm99_img.header) # R: nibabel.spm99analyze.Spm99AnalyzeHeader
65+
66+
spm2_img = Spm2AnalyzeImage(np.empty((5, 5, 5)), np.eye(4))
67+
reveal_type(spm2_img) # R: nibabel.spm2analyze.Spm2AnalyzeImage[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
68+
reveal_type(spm2_img.header) # R: nibabel.spm2analyze.Spm2AnalyzeHeader
69+
70+
ni1_img = Nifti1Image(np.empty((5, 5, 5)), np.eye(4))
71+
reveal_type(ni1_img) # R: nibabel.nifti1.Nifti1Image[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
72+
reveal_type(ni1_img.header) # R: nibabel.nifti1.Nifti1Header
73+
74+
ni2_img = Nifti2Image(np.empty((5, 5, 5)), np.eye(4))
75+
reveal_type(ni2_img) # R: nibabel.nifti2.Nifti2Image[numpy.ndarray[builtins.tuple[builtins.int, ...], numpy.dtype[numpy.float64]]]
76+
reveal_type(ni2_img.header) # R: nibabel.nifti2.Nifti2Header
77+
78+
mgh_img = MGHImage(np.empty((5, 5, 5), dtype=np.float32), np.eye(4))
79+
reveal_type(mgh_img) # R: nibabel.freesurfer.mghformat.MGHImage
80+
reveal_type(mgh_img.header) # R: nibabel.freesurfer.mghformat.MGHHeader

tox.ini

+13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ envlist =
1717
doctest
1818
style
1919
typecheck
20+
type-inference
2021
skip_missing_interpreters = true
2122

2223
# Configuration that allows us to split tests across GitHub runners effectively
@@ -179,6 +180,18 @@ skip_install = true
179180
commands =
180181
mypy nibabel
181182

183+
[testenv:type-inference]
184+
description = Check type inference
185+
labels = test
186+
deps =
187+
pytest-mypy-testing @ git+https://github.com/effigies/pytest-mypy-testing@rf/global-mypy-run
188+
commands =
189+
python -m pytest \
190+
--cov tests --cov nibabel --cov-report xml:cov.xml \
191+
--junitxml test-results.xml \
192+
--durations=20 --durations-min=1.0 \
193+
tests/ {posargs:-n auto}
194+
182195
[testenv:build{,-strict}]
183196
labels =
184197
check

0 commit comments

Comments
 (0)