Skip to content

Commit b692390

Browse files
committed
enh: guess open linear transform formats
EAFP implementation of loading linear transforms without specifying the format of the file. Resolves: #86. Resolves: #87. Resolves: #107.
1 parent 3b1f125 commit b692390

File tree

3 files changed

+70
-71
lines changed

3 files changed

+70
-71
lines changed

nitransforms/io/__init__.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,33 @@
11
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Read and write transforms."""
4-
from . import afni, fsl, itk, lta
4+
from nitransforms.io import afni, fsl, itk, lta
5+
from nitransforms.io.base import TransformFileError
56

67
__all__ = [
78
"afni",
89
"fsl",
910
"itk",
1011
"lta",
12+
"get_linear_factory",
13+
"TransformFileError",
1114
]
15+
16+
_IO_TYPES = {
17+
"itk": (itk, "ITKLinearTransform"),
18+
"ants": (itk, "ITKLinearTransform"),
19+
"elastix": (itk, "ITKLinearTransform"),
20+
"lta": (lta, "FSLinearTransform"),
21+
"fs": (lta, "FSLinearTransform"),
22+
"fsl": (fsl, "FSLLinearTransform"),
23+
"afni": (afni, "AFNILinearTransform"),
24+
}
25+
26+
27+
def get_linear_factory(fmt, is_array=True):
28+
"""Return the type required by a given format."""
29+
if fmt.lower() not in _IO_TYPES:
30+
raise TypeError(f"Unsupported transform format <{fmt}>.")
31+
32+
module, classname = _IO_TYPES[fmt.lower()]
33+
return getattr(module, f"{classname}{'Array' * is_array}")

nitransforms/linear.py

+43-67
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515
from nibabel.loadsave import load as _nbload
1616

17-
from .base import (
17+
from nitransforms.base import (
1818
ImageGrid,
1919
TransformBase,
2020
SpatialReference,
2121
_as_homogeneous,
2222
EQUALITY_TOL,
2323
)
24-
from . import io
24+
from nitransforms.io import get_linear_factory, TransformFileError
2525

2626

2727
class Affine(TransformBase):
@@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root):
183183
self.reference._to_hdf5(x5_root.create_group("Reference"))
184184

185185
def to_filename(self, filename, fmt="X5", moving=None):
186-
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
187-
if fmt.lower() in ["itk", "ants", "elastix"]:
188-
itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix)
189-
itkobj.to_filename(filename)
190-
return filename
191-
192-
# Rest of the formats peek into moving and reference image grids
193-
moving = ImageGrid(moving) if moving is not None else self.reference
194-
195-
_factory = {
196-
"afni": io.afni.AFNILinearTransform,
197-
"fsl": io.fsl.FSLLinearTransform,
198-
"lta": io.lta.FSLinearTransform,
199-
"fs": io.lta.FSLinearTransform,
200-
}
201-
202-
if fmt not in _factory:
203-
raise NotImplementedError(f"Unsupported format <{fmt}>")
204-
205-
_factory[fmt].from_ras(
206-
self.matrix, moving=moving, reference=self.reference
207-
).to_filename(filename)
208-
return filename
186+
"""Store the transform in the requested output format."""
187+
writer = get_linear_factory(fmt, is_array=False)
209188

210-
@classmethod
211-
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
212-
"""Create an affine from a transform file."""
213189
if fmt.lower() in ("itk", "ants", "elastix"):
214-
_factory = io.itk.ITKLinearTransformArray
215-
elif fmt.lower() in ("lta", "fs"):
216-
_factory = io.lta.FSLinearTransformArray
217-
elif fmt.lower() == "fsl":
218-
_factory = io.fsl.FSLLinearTransformArray
219-
elif fmt.lower() == "afni":
220-
_factory = io.afni.AFNILinearTransformArray
190+
writer.from_ras(self.matrix).to_filename(filename)
221191
else:
222-
raise NotImplementedError
192+
# Rest of the formats peek into moving and reference image grids
193+
writer.from_ras(
194+
self.matrix,
195+
reference=self.reference,
196+
moving=ImageGrid(moving) if moving is not None else self.reference,
197+
).to_filename(filename)
198+
return filename
223199

224-
struct = _factory.from_filename(filename)
225-
matrix = struct.to_ras(reference=reference, moving=moving)
226-
if cls == Affine:
227-
if np.shape(matrix)[0] != 1:
228-
raise TypeError("Cannot load transform array '%s'" % filename)
229-
matrix = matrix[0]
230-
return cls(matrix, reference=reference)
200+
@classmethod
201+
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
202+
"""Create an affine from a transform file."""
203+
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
204+
205+
for potential_fmt in fmtlist:
206+
try:
207+
struct = get_linear_factory(potential_fmt).from_filename(filename)
208+
matrix = struct.to_ras(reference=reference, moving=moving)
209+
if cls == Affine:
210+
if np.shape(matrix)[0] != 1:
211+
raise TypeError("Cannot load transform array '%s'" % filename)
212+
matrix = matrix[0]
213+
return cls(matrix, reference=reference)
214+
except TransformFileError:
215+
continue
216+
217+
raise TransformFileError(
218+
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)}."
219+
)
231220

232221
def __repr__(self):
233222
"""
@@ -353,31 +342,18 @@ def map(self, x, inverse=False):
353342
return np.swapaxes(affine.dot(coords), 1, 2)
354343

355344
def to_filename(self, filename, fmt="X5", moving=None):
356-
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
357-
if fmt.lower() in ("itk", "ants", "elastix"):
358-
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
359-
itkobj.to_filename(filename)
360-
return filename
345+
"""Store the transform in the requested output format."""
346+
writer = get_linear_factory(fmt, is_array=True)
361347

362-
# Rest of the formats peek into moving and reference image grids
363-
if moving is not None:
364-
moving = ImageGrid(moving)
348+
if fmt.lower() in ("itk", "ants", "elastix"):
349+
writer.from_ras(self.matrix).to_filename(filename)
365350
else:
366-
moving = self.reference
367-
368-
_factory = {
369-
"afni": io.afni.AFNILinearTransformArray,
370-
"fsl": io.fsl.FSLLinearTransformArray,
371-
"lta": io.lta.FSLinearTransformArray,
372-
"fs": io.lta.FSLinearTransformArray,
373-
}
374-
375-
if fmt not in _factory:
376-
raise NotImplementedError(f"Unsupported format <{fmt}>")
377-
378-
_factory[fmt].from_ras(
379-
self.matrix, moving=moving, reference=self.reference
380-
).to_filename(filename)
351+
# Rest of the formats peek into moving and reference image grids
352+
writer.from_ras(
353+
self.matrix,
354+
reference=self.reference,
355+
moving=ImageGrid(moving) if moving is not None else self.reference,
356+
).to_filename(filename)
381357
return filename
382358

383359
def apply(
@@ -486,7 +462,7 @@ def apply(
486462
return resampled
487463

488464

489-
def load(filename, fmt="X5", reference=None, moving=None):
465+
def load(filename, fmt=None, reference=None, moving=None):
490466
"""
491467
Load a linear transform file.
492468

nitransforms/tests/test_linear.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import nibabel as nb
1212
from nibabel.eulerangles import euler2mat
1313
from nibabel.affines import from_matvec
14-
from .. import linear as nitl
14+
from nitransforms import linear as nitl
15+
from nitransforms import io
1516
from .utils import assert_affines_by_filename
1617

1718
RMSE_TOL = 0.1
@@ -152,7 +153,7 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool
152153
xfm = (
153154
nitl.Affine(T) if (sw_tool, image_orientation) != ("afni", "oblique") else
154155
# AFNI is special when moving or reference are oblique - let io do the magic
155-
nitl.Affine(nitl.io.afni.AFNILinearTransform.from_ras(T).to_ras(
156+
nitl.Affine(io.afni.AFNILinearTransform.from_ras(T).to_ras(
156157
reference=img,
157158
moving=img,
158159
))
@@ -199,7 +200,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient
199200
xfm_fname = "M.%s%s" % (sw_tool, ext)
200201
# Change reference dataset for AFNI & oblique
201202
if (sw_tool, image_orientation) == ("afni", "oblique"):
202-
nitl.io.afni.AFNILinearTransform.from_ras(
203+
io.afni.AFNILinearTransform.from_ras(
203204
T,
204205
moving=img,
205206
reference=img,

0 commit comments

Comments
 (0)