|
14 | 14 |
|
15 | 15 | from nibabel.loadsave import load as _nbload
|
16 | 16 |
|
17 |
| -from .base import ( |
| 17 | +from nitransforms.base import ( |
18 | 18 | ImageGrid,
|
19 | 19 | TransformBase,
|
20 | 20 | SpatialReference,
|
21 | 21 | _as_homogeneous,
|
22 | 22 | EQUALITY_TOL,
|
23 | 23 | )
|
24 |
| -from . import io |
| 24 | +from nitransforms.io import get_linear_factory, TransformFileError |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class Affine(TransformBase):
|
@@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root):
|
183 | 183 | self.reference._to_hdf5(x5_root.create_group("Reference"))
|
184 | 184 |
|
185 | 185 | def to_filename(self, filename, fmt="X5", moving=None):
|
186 |
| - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" |
187 |
| - if fmt.lower() in ["itk", "ants", "elastix"]: |
188 |
| - itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix) |
189 |
| - itkobj.to_filename(filename) |
190 |
| - return filename |
191 |
| - |
192 |
| - # Rest of the formats peek into moving and reference image grids |
193 |
| - moving = ImageGrid(moving) if moving is not None else self.reference |
194 |
| - |
195 |
| - _factory = { |
196 |
| - "afni": io.afni.AFNILinearTransform, |
197 |
| - "fsl": io.fsl.FSLLinearTransform, |
198 |
| - "lta": io.lta.FSLinearTransform, |
199 |
| - "fs": io.lta.FSLinearTransform, |
200 |
| - } |
201 |
| - |
202 |
| - if fmt not in _factory: |
203 |
| - raise NotImplementedError(f"Unsupported format <{fmt}>") |
204 |
| - |
205 |
| - _factory[fmt].from_ras( |
206 |
| - self.matrix, moving=moving, reference=self.reference |
207 |
| - ).to_filename(filename) |
208 |
| - return filename |
| 186 | + """Store the transform in the requested output format.""" |
| 187 | + writer = get_linear_factory(fmt, is_array=False) |
209 | 188 |
|
210 |
| - @classmethod |
211 |
| - def from_filename(cls, filename, fmt="X5", reference=None, moving=None): |
212 |
| - """Create an affine from a transform file.""" |
213 | 189 | if fmt.lower() in ("itk", "ants", "elastix"):
|
214 |
| - _factory = io.itk.ITKLinearTransformArray |
215 |
| - elif fmt.lower() in ("lta", "fs"): |
216 |
| - _factory = io.lta.FSLinearTransformArray |
217 |
| - elif fmt.lower() == "fsl": |
218 |
| - _factory = io.fsl.FSLLinearTransformArray |
219 |
| - elif fmt.lower() == "afni": |
220 |
| - _factory = io.afni.AFNILinearTransformArray |
| 190 | + writer.from_ras(self.matrix).to_filename(filename) |
221 | 191 | else:
|
222 |
| - raise NotImplementedError |
| 192 | + # Rest of the formats peek into moving and reference image grids |
| 193 | + writer.from_ras( |
| 194 | + self.matrix, |
| 195 | + reference=self.reference, |
| 196 | + moving=ImageGrid(moving) if moving is not None else self.reference, |
| 197 | + ).to_filename(filename) |
| 198 | + return filename |
223 | 199 |
|
224 |
| - struct = _factory.from_filename(filename) |
225 |
| - matrix = struct.to_ras(reference=reference, moving=moving) |
226 |
| - if cls == Affine: |
227 |
| - if np.shape(matrix)[0] != 1: |
228 |
| - raise TypeError("Cannot load transform array '%s'" % filename) |
229 |
| - matrix = matrix[0] |
230 |
| - return cls(matrix, reference=reference) |
| 200 | + @classmethod |
| 201 | + def from_filename(cls, filename, fmt=None, reference=None, moving=None): |
| 202 | + """Create an affine from a transform file.""" |
| 203 | + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") |
| 204 | + |
| 205 | + for potential_fmt in fmtlist: |
| 206 | + try: |
| 207 | + struct = get_linear_factory(potential_fmt).from_filename(filename) |
| 208 | + matrix = struct.to_ras(reference=reference, moving=moving) |
| 209 | + if cls == Affine: |
| 210 | + if np.shape(matrix)[0] != 1: |
| 211 | + raise TypeError("Cannot load transform array '%s'" % filename) |
| 212 | + matrix = matrix[0] |
| 213 | + return cls(matrix, reference=reference) |
| 214 | + except (TransformFileError, FileNotFoundError): |
| 215 | + continue |
| 216 | + |
| 217 | + raise TransformFileError( |
| 218 | + f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})." |
| 219 | + ) |
231 | 220 |
|
232 | 221 | def __repr__(self):
|
233 | 222 | """
|
@@ -353,31 +342,18 @@ def map(self, x, inverse=False):
|
353 | 342 | return np.swapaxes(affine.dot(coords), 1, 2)
|
354 | 343 |
|
355 | 344 | def to_filename(self, filename, fmt="X5", moving=None):
|
356 |
| - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" |
357 |
| - if fmt.lower() in ("itk", "ants", "elastix"): |
358 |
| - itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix) |
359 |
| - itkobj.to_filename(filename) |
360 |
| - return filename |
| 345 | + """Store the transform in the requested output format.""" |
| 346 | + writer = get_linear_factory(fmt, is_array=True) |
361 | 347 |
|
362 |
| - # Rest of the formats peek into moving and reference image grids |
363 |
| - if moving is not None: |
364 |
| - moving = ImageGrid(moving) |
| 348 | + if fmt.lower() in ("itk", "ants", "elastix"): |
| 349 | + writer.from_ras(self.matrix).to_filename(filename) |
365 | 350 | else:
|
366 |
| - moving = self.reference |
367 |
| - |
368 |
| - _factory = { |
369 |
| - "afni": io.afni.AFNILinearTransformArray, |
370 |
| - "fsl": io.fsl.FSLLinearTransformArray, |
371 |
| - "lta": io.lta.FSLinearTransformArray, |
372 |
| - "fs": io.lta.FSLinearTransformArray, |
373 |
| - } |
374 |
| - |
375 |
| - if fmt not in _factory: |
376 |
| - raise NotImplementedError(f"Unsupported format <{fmt}>") |
377 |
| - |
378 |
| - _factory[fmt].from_ras( |
379 |
| - self.matrix, moving=moving, reference=self.reference |
380 |
| - ).to_filename(filename) |
| 351 | + # Rest of the formats peek into moving and reference image grids |
| 352 | + writer.from_ras( |
| 353 | + self.matrix, |
| 354 | + reference=self.reference, |
| 355 | + moving=ImageGrid(moving) if moving is not None else self.reference, |
| 356 | + ).to_filename(filename) |
381 | 357 | return filename
|
382 | 358 |
|
383 | 359 | def apply(
|
@@ -486,17 +462,17 @@ def apply(
|
486 | 462 | return resampled
|
487 | 463 |
|
488 | 464 |
|
489 |
| -def load(filename, fmt="X5", reference=None, moving=None): |
| 465 | +def load(filename, fmt=None, reference=None, moving=None): |
490 | 466 | """
|
491 | 467 | Load a linear transform file.
|
492 | 468 |
|
493 | 469 | Examples
|
494 | 470 | --------
|
495 |
| - >>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk") |
| 471 | + >>> xfm = load(regress_dir / "affine-LAS.itk.tfm") |
496 | 472 | >>> isinstance(xfm, Affine)
|
497 | 473 | True
|
498 | 474 |
|
499 |
| - >>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk") |
| 475 | + >>> xfm = load(regress_dir / "itktflist.tfm") |
500 | 476 | >>> isinstance(xfm, LinearTransformsMapping)
|
501 | 477 | True
|
502 | 478 |
|
|
0 commit comments