Skip to content

Commit 1a34ccc

Browse files
authored
Merge pull request #160 from nipy/enh/autoload-linear
ENH: Guess open linear transform formats
2 parents 2fa335d + bb8bf32 commit 1a34ccc

File tree

8 files changed

+144
-101
lines changed

8 files changed

+144
-101
lines changed

nitransforms/io/__init__.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,34 @@
11
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Read and write transforms."""
4-
from . import afni, fsl, itk, lta
4+
from nitransforms.io import afni, fsl, itk, lta
5+
from nitransforms.io.base import TransformIOError, TransformFileError
56

67
__all__ = [
78
"afni",
89
"fsl",
910
"itk",
1011
"lta",
12+
"get_linear_factory",
13+
"TransformFileError",
14+
"TransformIOError",
1115
]
16+
17+
_IO_TYPES = {
18+
"itk": (itk, "ITKLinearTransform"),
19+
"ants": (itk, "ITKLinearTransform"),
20+
"elastix": (itk, "ITKLinearTransform"),
21+
"lta": (lta, "FSLinearTransform"),
22+
"fs": (lta, "FSLinearTransform"),
23+
"fsl": (fsl, "FSLLinearTransform"),
24+
"afni": (afni, "AFNILinearTransform"),
25+
}
26+
27+
28+
def get_linear_factory(fmt, is_array=True):
29+
"""Return the type required by a given format."""
30+
if fmt.lower() not in _IO_TYPES:
31+
raise TypeError(f"Unsupported transform format <{fmt}>.")
32+
33+
module, classname = _IO_TYPES[fmt.lower()]
34+
return getattr(module, f"{classname}{'Array' * is_array}")

nitransforms/io/afni.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,16 @@ def from_string(cls, string):
9595
if not lines:
9696
raise TransformFileError
9797

98-
parameters = np.vstack(
99-
(
100-
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
101-
(0.0, 0.0, 0.0, 1.0),
98+
try:
99+
parameters = np.vstack(
100+
(
101+
np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)),
102+
(0.0, 0.0, 0.0, 1.0),
103+
)
102104
)
103-
)
105+
except ValueError as e:
106+
raise TransformFileError from e
107+
104108
sa["parameters"] = parameters
105109
return tf
106110

nitransforms/io/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
from ..patched import LabeledWrapStruct
77

88

9-
class TransformFileError(Exception):
10-
"""A custom exception for transform files."""
9+
class TransformIOError(IOError):
10+
"""General I/O exception while reading/writing transforms."""
11+
12+
13+
class TransformFileError(TransformIOError):
14+
"""Specific I/O exception when a file does not meet the expected format."""
1115

1216

1317
class StringBasedStruct(LabeledWrapStruct):

nitransforms/io/fsl.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
BaseLinearTransformList,
1111
LinearParameters,
1212
DisplacementsField,
13+
TransformIOError,
1314
TransformFileError,
1415
_ensure_image,
1516
)
@@ -40,7 +41,7 @@ def from_ras(cls, ras, moving=None, reference=None):
4041
moving = reference
4142

4243
if reference is None:
43-
raise ValueError("Cannot build FSL linear transform without a reference")
44+
raise TransformIOError("Cannot build FSL linear transform without a reference")
4445

4546
reference = _ensure_image(reference)
4647
moving = _ensure_image(moving)
@@ -77,7 +78,7 @@ def from_string(cls, string):
7778
def to_ras(self, moving=None, reference=None):
7879
"""Return a nitransforms internal RAS+ matrix."""
7980
if reference is None:
80-
raise ValueError("Cannot build FSL linear transform without a reference")
81+
raise TransformIOError("Cannot build FSL linear transform without a reference")
8182

8283
if moving is None:
8384
warnings.warn(

nitransforms/io/itk.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
BaseLinearTransformList,
99
DisplacementsField,
1010
LinearParameters,
11+
TransformIOError,
1112
TransformFileError,
1213
)
1314

@@ -306,7 +307,7 @@ def from_filename(cls, filename):
306307
from h5py import File as H5File
307308

308309
if not str(filename).endswith(".h5"):
309-
raise RuntimeError("Extension is not .h5")
310+
raise TransformFileError("Extension is not .h5")
310311

311312
with H5File(str(filename)) as f:
312313
return cls.from_h5obj(f)
@@ -354,7 +355,7 @@ def from_h5obj(cls, fileobj, check=True):
354355
)
355356
continue
356357

357-
raise NotImplementedError(
358+
raise TransformIOError(
358359
f"Unsupported transform type {xfm['TransformType'][0]}"
359360
)
360361

nitransforms/linear.py

+45-69
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515
from nibabel.loadsave import load as _nbload
1616

17-
from .base import (
17+
from nitransforms.base import (
1818
ImageGrid,
1919
TransformBase,
2020
SpatialReference,
2121
_as_homogeneous,
2222
EQUALITY_TOL,
2323
)
24-
from . import io
24+
from nitransforms.io import get_linear_factory, TransformFileError
2525

2626

2727
class Affine(TransformBase):
@@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root):
183183
self.reference._to_hdf5(x5_root.create_group("Reference"))
184184

185185
def to_filename(self, filename, fmt="X5", moving=None):
186-
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
187-
if fmt.lower() in ["itk", "ants", "elastix"]:
188-
itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix)
189-
itkobj.to_filename(filename)
190-
return filename
191-
192-
# Rest of the formats peek into moving and reference image grids
193-
moving = ImageGrid(moving) if moving is not None else self.reference
194-
195-
_factory = {
196-
"afni": io.afni.AFNILinearTransform,
197-
"fsl": io.fsl.FSLLinearTransform,
198-
"lta": io.lta.FSLinearTransform,
199-
"fs": io.lta.FSLinearTransform,
200-
}
201-
202-
if fmt not in _factory:
203-
raise NotImplementedError(f"Unsupported format <{fmt}>")
204-
205-
_factory[fmt].from_ras(
206-
self.matrix, moving=moving, reference=self.reference
207-
).to_filename(filename)
208-
return filename
186+
"""Store the transform in the requested output format."""
187+
writer = get_linear_factory(fmt, is_array=False)
209188

210-
@classmethod
211-
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
212-
"""Create an affine from a transform file."""
213189
if fmt.lower() in ("itk", "ants", "elastix"):
214-
_factory = io.itk.ITKLinearTransformArray
215-
elif fmt.lower() in ("lta", "fs"):
216-
_factory = io.lta.FSLinearTransformArray
217-
elif fmt.lower() == "fsl":
218-
_factory = io.fsl.FSLLinearTransformArray
219-
elif fmt.lower() == "afni":
220-
_factory = io.afni.AFNILinearTransformArray
190+
writer.from_ras(self.matrix).to_filename(filename)
221191
else:
222-
raise NotImplementedError
192+
# Rest of the formats peek into moving and reference image grids
193+
writer.from_ras(
194+
self.matrix,
195+
reference=self.reference,
196+
moving=ImageGrid(moving) if moving is not None else self.reference,
197+
).to_filename(filename)
198+
return filename
223199

224-
struct = _factory.from_filename(filename)
225-
matrix = struct.to_ras(reference=reference, moving=moving)
226-
if cls == Affine:
227-
if np.shape(matrix)[0] != 1:
228-
raise TypeError("Cannot load transform array '%s'" % filename)
229-
matrix = matrix[0]
230-
return cls(matrix, reference=reference)
200+
@classmethod
201+
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
202+
"""Create an affine from a transform file."""
203+
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
204+
205+
for potential_fmt in fmtlist:
206+
try:
207+
struct = get_linear_factory(potential_fmt).from_filename(filename)
208+
matrix = struct.to_ras(reference=reference, moving=moving)
209+
if cls == Affine:
210+
if np.shape(matrix)[0] != 1:
211+
raise TypeError("Cannot load transform array '%s'" % filename)
212+
matrix = matrix[0]
213+
return cls(matrix, reference=reference)
214+
except (TransformFileError, FileNotFoundError):
215+
continue
216+
217+
raise TransformFileError(
218+
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
219+
)
231220

232221
def __repr__(self):
233222
"""
@@ -353,31 +342,18 @@ def map(self, x, inverse=False):
353342
return np.swapaxes(affine.dot(coords), 1, 2)
354343

355344
def to_filename(self, filename, fmt="X5", moving=None):
356-
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
357-
if fmt.lower() in ("itk", "ants", "elastix"):
358-
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
359-
itkobj.to_filename(filename)
360-
return filename
345+
"""Store the transform in the requested output format."""
346+
writer = get_linear_factory(fmt, is_array=True)
361347

362-
# Rest of the formats peek into moving and reference image grids
363-
if moving is not None:
364-
moving = ImageGrid(moving)
348+
if fmt.lower() in ("itk", "ants", "elastix"):
349+
writer.from_ras(self.matrix).to_filename(filename)
365350
else:
366-
moving = self.reference
367-
368-
_factory = {
369-
"afni": io.afni.AFNILinearTransformArray,
370-
"fsl": io.fsl.FSLLinearTransformArray,
371-
"lta": io.lta.FSLinearTransformArray,
372-
"fs": io.lta.FSLinearTransformArray,
373-
}
374-
375-
if fmt not in _factory:
376-
raise NotImplementedError(f"Unsupported format <{fmt}>")
377-
378-
_factory[fmt].from_ras(
379-
self.matrix, moving=moving, reference=self.reference
380-
).to_filename(filename)
351+
# Rest of the formats peek into moving and reference image grids
352+
writer.from_ras(
353+
self.matrix,
354+
reference=self.reference,
355+
moving=ImageGrid(moving) if moving is not None else self.reference,
356+
).to_filename(filename)
381357
return filename
382358

383359
def apply(
@@ -486,17 +462,17 @@ def apply(
486462
return resampled
487463

488464

489-
def load(filename, fmt="X5", reference=None, moving=None):
465+
def load(filename, fmt=None, reference=None, moving=None):
490466
"""
491467
Load a linear transform file.
492468
493469
Examples
494470
--------
495-
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk")
471+
>>> xfm = load(regress_dir / "affine-LAS.itk.tfm")
496472
>>> isinstance(xfm, Affine)
497473
True
498474
499-
>>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk")
475+
>>> xfm = load(regress_dir / "itktflist.tfm")
500476
>>> isinstance(xfm, LinearTransformsMapping)
501477
True
502478

nitransforms/tests/test_io.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""I/O test cases."""
4+
import os
45
from subprocess import check_call
56
from io import StringIO
67
import filecmp
78
import shutil
89
import numpy as np
910
import pytest
11+
from h5py import File as H5File
1012

1113
import nibabel as nb
1214
from nibabel.eulerangles import euler2mat
@@ -24,7 +26,7 @@
2426
FSLinearTransform as LT,
2527
FSLinearTransformArray as LTA,
2628
)
27-
from ..io.base import LinearParameters, TransformFileError
29+
from ..io.base import LinearParameters, TransformIOError, TransformFileError
2830

2931
LPS = np.diag([-1, -1, 1, 1])
3032
ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS))
@@ -224,7 +226,7 @@ def test_Linear_common(tmpdir, data_path, sw, image_orientation, get_testdata):
224226

225227
# Test without images
226228
if sw == "fsl":
227-
with pytest.raises(ValueError):
229+
with pytest.raises(TransformIOError):
228230
factory.from_ras(RAS)
229231
else:
230232
xfm = factory.from_ras(RAS)
@@ -408,7 +410,7 @@ def test_afni_Displacements():
408410
afni.AFNIDisplacementsField.from_image(field)
409411

410412

411-
def test_itk_h5(testdata_path):
413+
def test_itk_h5(tmpdir, testdata_path):
412414
"""Test displacements fields."""
413415
assert (
414416
len(
@@ -422,14 +424,29 @@ def test_itk_h5(testdata_path):
422424
== 2
423425
)
424426

425-
with pytest.raises(RuntimeError):
427+
with pytest.raises(TransformFileError):
426428
list(
427429
itk.ITKCompositeH5.from_filename(
428430
testdata_path
429431
/ "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.x5"
430432
)
431433
)
432434

435+
tmpdir.chdir()
436+
shutil.copy(
437+
testdata_path / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5",
438+
"test.h5",
439+
)
440+
os.chmod("test.h5", 0o666)
441+
442+
with H5File("test.h5", "r+") as h5file:
443+
h5group = h5file["TransformGroup"]
444+
xfm = h5group[list(h5group.keys())[1]]
445+
xfm["TransformType"][0] = b"InventTransform"
446+
447+
with pytest.raises(TransformIOError):
448+
itk.ITKCompositeH5.from_filename("test.h5")
449+
433450

434451
@pytest.mark.parametrize(
435452
"file_type, test_file", [(LTA, "from-fsnative_to-scanner_mode-image.lta")]

0 commit comments

Comments
 (0)