diff --git a/nibabel/cifti2/cifti2.py b/nibabel/cifti2/cifti2.py index 0b2b7f2a9a..a0cd22fa58 100644 --- a/nibabel/cifti2/cifti2.py +++ b/nibabel/cifti2/cifti2.py @@ -25,6 +25,7 @@ from ..nifti1 import Nifti1Extensions from ..nifti2 import Nifti2Image, Nifti2Header from ..arrayproxy import reshape_dataobj +from ..volumeutils import Recoder from warnings import warn @@ -89,6 +90,53 @@ class Cifti2HeaderError(Exception): 'CIFTI_STRUCTURE_THALAMUS_LEFT', 'CIFTI_STRUCTURE_THALAMUS_RIGHT') +# "Standard CIFTI Mapping Combinations" within CIFTI-2 spec +# https://www.nitrc.org/forum/attachment.php?attachid=341&group_id=454&forum_id=1955 +CIFTI_CODES = Recoder(( + ('.dconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE', ( + 'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.dtseries.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', ( + 'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.pconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED', ( + 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', + )), + ('.ptseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES', ( + 'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_PARCELS', + )), + ('.dscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS', ( + 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.dlabel.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS', ( + 'CIFTI_INDEX_TYPE_LABELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.pscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR', ( + 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_PARCELS', + )), + ('.pdconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE', ( + 'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS', + )), + ('.dpconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED', ( + 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.pconnseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES', ( + 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES', + )), + ('.pconnscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR', ( + 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SCALARS', + )), + ('.dfan.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', ( + 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.dfibersamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', ( + 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), + ('.dfansamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', ( + 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS', + )), +), fields=('extension', 'niistring', 'map_types')) + def _value_if_klass(val, klass): if val is None or isinstance(val, klass): @@ -1466,11 +1514,7 @@ def to_file_map(self, file_map=None): raise ValueError( f"Dataobj shape {self._dataobj.shape} does not match shape " f"expected from CIFTI-2 header {self.header.matrix.get_data_shape()}") - # if intent code is not set, default to unknown CIFTI - if header.get_intent()[0] == 'none': - header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN') - data = reshape_dataobj(self.dataobj, - (1, 1, 1, 1) + self.dataobj.shape) + data = reshape_dataobj(self.dataobj, (1, 1, 1, 1) + self.dataobj.shape) # If qform not set, reset pixdim values so Nifti2 does not complain if header['qform_code'] == 0: header['pixdim'][:4] = 1 @@ -1501,7 +1545,11 @@ def update_headers(self): >>> img.shape == (2, 3, 4) True """ - self._nifti_header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape) + header = self._nifti_header + header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape) + # if intent code is not set, default to unknown + if header.get_intent()[0] == 'none': + header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN') def get_data_dtype(self): return self._nifti_header.get_data_dtype() @@ -1509,6 +1557,47 @@ def get_data_dtype(self): def set_data_dtype(self, dtype): self._nifti_header.set_data_dtype(dtype) + def to_filename(self, filename, validate=True): + """ + Ensures NIfTI header intent code is set prior to saving. + + Parameters + ---------- + validate : boolean, optional + If ``True``, infer and validate CIFTI type based on MatrixIndicesMap values. + This includes the setting of the relevant intent code within the NIfTI header. + If validation fails, a UserWarning is issued and saving continues. + """ + if validate: + # Determine CIFTI type via index maps + from .parse_cifti2 import intent_codes + + matrix = self.header.matrix + map_types = tuple( + matrix.get_index_map(idx).indices_map_to_data_type for idx + in sorted(matrix.mapped_indices) + ) + try: + expected_intent = CIFTI_CODES.niistring[map_types] + expected_ext = CIFTI_CODES.extension[map_types] + except KeyError: # unknown + expected_intent = "NIFTI_INTENT_CONNECTIVITY_UNKNOWN" + expected_ext = None + warn( + "No information found for matrix containing the following index maps:" + f"{map_types}, defaulting to unknown." + ) + + orig_intent = self._nifti_header.get_intent()[0] + if expected_intent != intent_codes.niistring[orig_intent]: + warn( + f"Expected NIfTI intent: {expected_intent} has been automatically set." + ) + self._nifti_header.set_intent(expected_intent) + if expected_ext is not None and not filename.endswith(expected_ext): + warn(f"Filename does not end with expected extension: {expected_ext}") + super().to_filename(filename) + load = Cifti2Image.from_filename save = Cifti2Image.instance_to_filename diff --git a/nibabel/cifti2/tests/test_cifti2.py b/nibabel/cifti2/tests/test_cifti2.py index ea571065de..f64e11c5df 100644 --- a/nibabel/cifti2/tests/test_cifti2.py +++ b/nibabel/cifti2/tests/test_cifti2.py @@ -427,3 +427,18 @@ def make_imaker(self, arr, header=None, ni_header=None): ) header.matrix.append(mim) return lambda: self.image_maker(arr.copy(), header, ni_header) + + def validate_filenames(self, imaker, params, validate=False): + super().validate_filenames(imaker, params, validate=validate) + + def validate_mmap_parameter(self, imaker, params, validate=False): + super().validate_mmap_parameter(imaker, params, validate=validate) + + def validate_to_bytes(self, imaker, params, validate=False): + super().validate_to_bytes(imaker, params, validate=validate) + + def validate_from_bytes(self, imaker, params, validate=False): + super().validate_from_bytes(imaker, params, validate=validate) + + def validate_to_from_bytes(self, imaker, params, validate=False): + super().validate_to_from_bytes(imaker, params, validate=validate) diff --git a/nibabel/cifti2/tests/test_cifti2io_axes.py b/nibabel/cifti2/tests/test_cifti2io_axes.py index c237e3c61a..1abd129ae5 100644 --- a/nibabel/cifti2/tests/test_cifti2io_axes.py +++ b/nibabel/cifti2/tests/test_cifti2io_axes.py @@ -91,7 +91,7 @@ def check_rewrite(arr, axes, extension='.nii'): custom extension to use """ (fd, name) = tempfile.mkstemp(extension) - cifti2.Cifti2Image(arr, header=axes).to_filename(name) + cifti2.Cifti2Image(arr, header=axes).to_filename(name, validate=False) img = nib.load(name) arr2 = img.get_fdata() assert np.allclose(arr, arr2) diff --git a/nibabel/cifti2/tests/test_cifti2io_header.py b/nibabel/cifti2/tests/test_cifti2io_header.py index df4fe10fcd..f1f1f6c811 100644 --- a/nibabel/cifti2/tests/test_cifti2io_header.py +++ b/nibabel/cifti2/tests/test_cifti2io_header.py @@ -83,7 +83,7 @@ def test_readwritedata(): with InTemporaryDirectory(): for name in datafiles: img = ci.load(name) - ci.save(img, 'test.nii') + ci.save(img, 'test.nii', validate=False) img2 = ci.load('test.nii') assert len(img.header.matrix) == len(img2.header.matrix) # Order should be preserved in load/save @@ -109,7 +109,7 @@ def test_nibabel_readwritedata(): with InTemporaryDirectory(): for name in datafiles: img = nib.load(name) - nib.save(img, 'test.nii') + nib.save(img, 'test.nii', validate=False) img2 = nib.load('test.nii') assert len(img.header.matrix) == len(img2.header.matrix) # Order should be preserved in load/save diff --git a/nibabel/cifti2/tests/test_new_cifti2.py b/nibabel/cifti2/tests/test_new_cifti2.py index 65ef95c316..026fc920f0 100644 --- a/nibabel/cifti2/tests/test_new_cifti2.py +++ b/nibabel/cifti2/tests/test_new_cifti2.py @@ -7,12 +7,11 @@ scratch. """ import numpy as np - import nibabel as nib from nibabel import cifti2 as ci from nibabel.tmpdirs import InTemporaryDirectory - import pytest + from ...testing import ( clear_and_catch_warnings, error_warnings, suppress_warnings, assert_array_equal) @@ -237,7 +236,6 @@ def test_dtseries(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(13, 10) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES') with InTemporaryDirectory(): ci.save(img, 'test.dtseries.nii') @@ -281,7 +279,6 @@ def test_dlabel(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(2, 10) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS') with InTemporaryDirectory(): ci.save(img, 'test.dlabel.nii') @@ -301,7 +298,6 @@ def test_dconn(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(10, 10) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE') with InTemporaryDirectory(): ci.save(img, 'test.dconn.nii') @@ -323,7 +319,6 @@ def test_ptseries(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(13, 4) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES') with InTemporaryDirectory(): ci.save(img, 'test.ptseries.nii') @@ -345,7 +340,6 @@ def test_pscalar(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(2, 4) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR') with InTemporaryDirectory(): ci.save(img, 'test.pscalar.nii') @@ -367,7 +361,6 @@ def test_pdconn(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(10, 4) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE') with InTemporaryDirectory(): ci.save(img, 'test.pdconn.nii') @@ -389,7 +382,6 @@ def test_dpconn(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(4, 10) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED') with InTemporaryDirectory(): ci.save(img, 'test.dpconn.nii') @@ -413,7 +405,7 @@ def test_plabel(): img = ci.Cifti2Image(data, hdr) with InTemporaryDirectory(): - ci.save(img, 'test.plabel.nii') + ci.save(img, 'test.plabel.nii', validate=False) img2 = ci.load('test.plabel.nii') assert img.nifti_header.get_intent()[0] == 'ConnUnknown' assert isinstance(img2, ci.Cifti2Image) @@ -430,7 +422,6 @@ def test_pconn(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(4, 4) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED') with InTemporaryDirectory(): ci.save(img, 'test.pconn.nii') @@ -453,8 +444,6 @@ def test_pconnseries(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(4, 4, 13) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_' - 'PARCELLATED_SERIES') with InTemporaryDirectory(): ci.save(img, 'test.pconnseries.nii') @@ -478,8 +467,6 @@ def test_pconnscalar(): hdr = ci.Cifti2Header(matrix) data = np.random.randn(4, 4, 2) img = ci.Cifti2Image(data, hdr) - img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_' - 'PARCELLATED_SCALAR') with InTemporaryDirectory(): ci.save(img, 'test.pconnscalar.nii') @@ -517,7 +504,45 @@ def test_wrong_shape(): ci.Cifti2Image(data, hdr) with suppress_warnings(): img = ci.Cifti2Image(data, hdr) - + with pytest.raises(ValueError): img.to_file_map() + +def test_cifti_validation(): + # flip label / brain_model index maps + geometry_map = create_geometry_map((0, )) + label_map = create_label_map((1, )) + matrix = ci.Cifti2Matrix() + matrix.append(geometry_map) + matrix.append(label_map) + hdr = ci.Cifti2Header(matrix) + data = np.random.randn(10, 2) + img = ci.Cifti2Image(data, hdr) + # flipped index maps will warn + with InTemporaryDirectory(), pytest.warns(UserWarning): + ci.save(img, 'test.dlabel.nii') + + label_map = create_label_map((0, )) + geometry_map = create_geometry_map((1, )) + matrix = ci.Cifti2Matrix() + matrix.append(label_map) + matrix.append(geometry_map) + hdr = ci.Cifti2Header(matrix) + data = np.random.randn(2, 10) + img = ci.Cifti2Image(data, hdr) + + with InTemporaryDirectory(): + ci.save(img, 'test.validate.nii', validate=False) + ci.save(img, 'test.dlabel.nii') + + img2 = nib.load('test.dlabel.nii') + img3 = nib.load('test.validate.nii') + assert img2.nifti_header.get_intent()[0] == 'ConnDenseLabel' + assert img3.nifti_header.get_intent()[0] == 'ConnUnknown' + assert isinstance(img2, ci.Cifti2Image) + assert isinstance(img3, ci.Cifti2Image) + assert_array_equal(img2.get_fdata(), data) + check_label_map(img2.header.matrix.get_index_map(0)) + check_geometry_map(img2.header.matrix.get_index_map(1)) + del img2, img3 \ No newline at end of file diff --git a/nibabel/filebasedimages.py b/nibabel/filebasedimages.py index 006b70d615..f610e87543 100644 --- a/nibabel/filebasedimages.py +++ b/nibabel/filebasedimages.py @@ -315,7 +315,7 @@ def filespec_to_file_map(klass, filespec): def filespec_to_files(klass, filespec): return klass.filespec_to_file_map(filespec) - def to_filename(self, filename): + def to_filename(self, filename, **kwargs): """ Write image to files implied by filename string Parameters @@ -381,7 +381,7 @@ def make_file_map(klass, mapping=None): load = from_filename @classmethod - def instance_to_filename(klass, img, filename): + def instance_to_filename(klass, img, filename, **kwargs): """ Save `img` in our own format, to name implied by `filename` This is a class method @@ -394,7 +394,7 @@ def instance_to_filename(klass, img, filename): Filename, implying name to which to save image. """ img = klass.from_image(img) - img.to_filename(filename) + img.to_filename(filename, **kwargs) @classmethod def from_image(klass, img): diff --git a/nibabel/loadsave.py b/nibabel/loadsave.py index f3a2b5876b..eb112b158c 100644 --- a/nibabel/loadsave.py +++ b/nibabel/loadsave.py @@ -78,7 +78,7 @@ def guessed_image_type(filename): raise ImageFileError(f'Cannot work out file type of "{filename}"') -def save(img, filename): +def save(img, filename, **kwargs): """ Save an image to file adapting format to `filename` Parameters @@ -96,7 +96,7 @@ def save(img, filename): # Save the type as expected try: - img.to_filename(filename) + img.to_filename(filename, **kwargs) except ImageFileError: pass else: @@ -144,7 +144,7 @@ def save(img, filename): # Here, we either have a klass or a converted image. if converted is None: converted = klass.from_image(img) - converted.to_filename(filename) + converted.to_filename(filename, **kwargs) @deprecate_with_version('read_img_data deprecated. ' diff --git a/nibabel/tests/test_image_api.py b/nibabel/tests/test_image_api.py index 392a493e53..9a9ae040c4 100644 --- a/nibabel/tests/test_image_api.py +++ b/nibabel/tests/test_image_api.py @@ -123,7 +123,7 @@ def validate_header_deprecated(self, imaker, params): hdr = img.get_header() assert hdr is img.header - def validate_filenames(self, imaker, params): + def validate_filenames(self, imaker, params, **kwargs): # Validate the filename, file_map interface if not self.can_save: @@ -160,7 +160,7 @@ def validate_filenames(self, imaker, params): warnings.filterwarnings('error', category=DeprecationWarning, module=r"nibabel.*") - img.to_filename(path) + img.to_filename(path, **kwargs) rt_img = img.__class__.from_filename(path) assert_array_equal(img.shape, rt_img.shape) assert_almost_equal(img.get_fdata(), rt_img.get_fdata()) @@ -456,7 +456,7 @@ def validate_shape_deprecated(self, imaker, params): with pytest.raises(ExpiredDeprecationError): img.get_shape() - def validate_mmap_parameter(self, imaker, params): + def validate_mmap_parameter(self, imaker, params, **kwargs): img = imaker() fname = img.get_filename() with InTemporaryDirectory(): @@ -468,7 +468,7 @@ def validate_mmap_parameter(self, imaker, params): if not img.rw or not img.valid_exts: return fname = 'image' + img.valid_exts[0] - img.to_filename(fname) + img.to_filename(fname, **kwargs) rt_img = img.__class__.from_filename(fname, mmap=True) assert_almost_equal(img.get_fdata(), rt_img.get_fdata()) rt_img = img.__class__.from_filename(fname, mmap=False) @@ -533,22 +533,22 @@ def validate_affine_deprecated(self, imaker, params): class SerializeMixin(object): - def validate_to_bytes(self, imaker, params): + def validate_to_bytes(self, imaker, params, **kwargs): img = imaker() serialized = img.to_bytes() with InTemporaryDirectory(): fname = 'img' + self.standard_extension - img.to_filename(fname) + img.to_filename(fname, **kwargs) with open(fname, 'rb') as fobj: file_contents = fobj.read() assert serialized == file_contents - def validate_from_bytes(self, imaker, params): + def validate_from_bytes(self, imaker, params, **kwargs): img = imaker() klass = getattr(self, 'klass', img.__class__) with InTemporaryDirectory(): fname = 'img' + self.standard_extension - img.to_filename(fname) + img.to_filename(fname, **kwargs) all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}] for img_params in all_images: @@ -561,12 +561,12 @@ def validate_from_bytes(self, imaker, params): del img_a del img_b - def validate_to_from_bytes(self, imaker, params): + def validate_to_from_bytes(self, imaker, params, **kwargs): img = imaker() klass = getattr(self, 'klass', img.__class__) with InTemporaryDirectory(): fname = 'img' + self.standard_extension - img.to_filename(fname) + img.to_filename(fname, **kwargs) all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}] for img_params in all_images: