Skip to content

Commit b356a64

Browse files
authored
Merge pull request #774 from MichielCottaar/check_cifti_shape
ENH: Check CIFTI-2 data shape matches shape described by header
2 parents 0282bb9 + c531421 commit b356a64

File tree

4 files changed

+95
-18
lines changed

4 files changed

+95
-18
lines changed

nibabel/batteryrunners.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def check_only(self, obj):
141141
-------
142142
reports : sequence
143143
sequence of report objects reporting on result of running
144-
checks (withou fixes) on `obj`
144+
checks (without fixes) on `obj`
145145
'''
146146
reports = []
147147
for check in self._checks:

nibabel/cifti2/cifti2.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..nifti2 import Nifti2Image, Nifti2Header
3131
from ..arrayproxy import reshape_dataobj
3232
from ..keywordonly import kw_only_meth
33+
from warnings import warn
3334

3435

3536
def _float_01(val):
@@ -1209,6 +1210,38 @@ def _to_xml_element(self):
12091210
mat.append(mim._to_xml_element())
12101211
return mat
12111212

1213+
def get_axis(self, index):
1214+
'''
1215+
Generates the Cifti2 axis for a given dimension
1216+
1217+
Parameters
1218+
----------
1219+
index : int
1220+
Dimension for which we want to obtain the mapping.
1221+
1222+
Returns
1223+
-------
1224+
axis : :class:`.cifti2_axes.Axis`
1225+
'''
1226+
from . import cifti2_axes
1227+
return cifti2_axes.from_index_mapping(self.get_index_map(index))
1228+
1229+
def get_data_shape(self):
1230+
"""
1231+
Returns data shape expected based on the CIFTI-2 header
1232+
1233+
Any dimensions omitted in the CIFTI-2 header will be given a default size of None.
1234+
"""
1235+
from . import cifti2_axes
1236+
if len(self.mapped_indices) == 0:
1237+
return ()
1238+
base_shape = [None] * (max(self.mapped_indices) + 1)
1239+
for mim in self:
1240+
size = len(cifti2_axes.from_index_mapping(mim))
1241+
for idx in mim.applies_to_matrix_dimension:
1242+
base_shape[idx] = size
1243+
return tuple(base_shape)
1244+
12121245

12131246
class Cifti2Header(FileBasedHeader, xml.XmlSerializable):
12141247
''' Class for CIFTI-2 header extension '''
@@ -1279,8 +1312,7 @@ def get_axis(self, index):
12791312
-------
12801313
axis : :class:`.cifti2_axes.Axis`
12811314
'''
1282-
from . import cifti2_axes
1283-
return cifti2_axes.from_index_mapping(self.matrix.get_index_map(index))
1315+
return self.matrix.get_axis(index)
12841316

12851317
@classmethod
12861318
def from_axes(cls, axes):
@@ -1345,12 +1377,18 @@ def __init__(self,
13451377
super(Cifti2Image, self).__init__(dataobj, header=header,
13461378
extra=extra, file_map=file_map)
13471379
self._nifti_header = Nifti2Header.from_header(nifti_header)
1380+
13481381
# if NIfTI header not specified, get data type from input array
13491382
if nifti_header is None:
13501383
if hasattr(dataobj, 'dtype'):
13511384
self._nifti_header.set_data_dtype(dataobj.dtype)
13521385
self.update_headers()
13531386

1387+
if self._dataobj.shape != self.header.matrix.get_data_shape():
1388+
warn("Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
1389+
self._dataobj.shape, self.header.matrix.get_data_shape()
1390+
))
1391+
13541392
@property
13551393
def nifti_header(self):
13561394
return self._nifti_header
@@ -1426,6 +1464,11 @@ def to_file_map(self, file_map=None):
14261464
header = self._nifti_header
14271465
extension = Cifti2Extension(content=self.header.to_xml())
14281466
header.extensions.append(extension)
1467+
if self._dataobj.shape != self.header.matrix.get_data_shape():
1468+
raise ValueError(
1469+
"Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
1470+
self._dataobj.shape, self.header.matrix.get_data_shape()
1471+
))
14291472
# if intent code is not set, default to unknown CIFTI
14301473
if header.get_intent()[0] == 'none':
14311474
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
@@ -1438,7 +1481,7 @@ def to_file_map(self, file_map=None):
14381481
img.to_file_map(file_map or self.file_map)
14391482

14401483
def update_headers(self):
1441-
''' Harmonize CIFTI-2 and NIfTI headers with image data
1484+
''' Harmonize NIfTI headers with image data
14421485
14431486
>>> import numpy as np
14441487
>>> data = np.zeros((2,3,4))

nibabel/cifti2/tests/test_cifti2.py

+6
Original file line numberDiff line numberDiff line change
@@ -358,4 +358,10 @@ class TestCifti2ImageAPI(_TDA):
358358
standard_extension = '.nii'
359359

360360
def make_imaker(self, arr, header=None, ni_header=None):
361+
for idx, sz in enumerate(arr.shape):
362+
maps = [ci.Cifti2NamedMap(str(value)) for value in range(sz)]
363+
mim = ci.Cifti2MatrixIndicesMap(
364+
(idx, ), 'CIFTI_INDEX_TYPE_SCALARS', maps=maps
365+
)
366+
header.matrix.append(mim)
361367
return lambda: self.image_maker(arr.copy(), header, ni_header)

nibabel/cifti2/tests/test_new_cifti2.py

+42-14
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
from nibabel import cifti2 as ci
1313
from nibabel.tmpdirs import InTemporaryDirectory
1414

15-
from nose.tools import assert_true, assert_equal
15+
from nose.tools import assert_true, assert_equal, assert_raises
16+
from nibabel.testing import clear_and_catch_warnings, error_warnings, suppress_warnings
1617

1718
affine = [[-1.5, 0, 0, 90],
1819
[0, 1.5, 0, -85],
19-
[0, 0, 1.5, -71]]
20+
[0, 0, 1.5, -71],
21+
[0, 0, 0, 1.]]
2022

2123
dimensions = (120, 83, 78)
2224

@@ -234,7 +236,7 @@ def test_dtseries():
234236
matrix.append(series_map)
235237
matrix.append(geometry_map)
236238
hdr = ci.Cifti2Header(matrix)
237-
data = np.random.randn(13, 9)
239+
data = np.random.randn(13, 10)
238240
img = ci.Cifti2Image(data, hdr)
239241
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')
240242

@@ -257,7 +259,7 @@ def test_dscalar():
257259
matrix.append(scalar_map)
258260
matrix.append(geometry_map)
259261
hdr = ci.Cifti2Header(matrix)
260-
data = np.random.randn(2, 9)
262+
data = np.random.randn(2, 10)
261263
img = ci.Cifti2Image(data, hdr)
262264
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS')
263265

@@ -279,7 +281,7 @@ def test_dlabel():
279281
matrix.append(label_map)
280282
matrix.append(geometry_map)
281283
hdr = ci.Cifti2Header(matrix)
282-
data = np.random.randn(2, 9)
284+
data = np.random.randn(2, 10)
283285
img = ci.Cifti2Image(data, hdr)
284286
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')
285287

@@ -299,7 +301,7 @@ def test_dconn():
299301
matrix = ci.Cifti2Matrix()
300302
matrix.append(mapping)
301303
hdr = ci.Cifti2Header(matrix)
302-
data = np.random.randn(9, 9)
304+
data = np.random.randn(10, 10)
303305
img = ci.Cifti2Image(data, hdr)
304306
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')
305307

@@ -322,7 +324,7 @@ def test_ptseries():
322324
matrix.append(series_map)
323325
matrix.append(parcel_map)
324326
hdr = ci.Cifti2Header(matrix)
325-
data = np.random.randn(13, 3)
327+
data = np.random.randn(13, 4)
326328
img = ci.Cifti2Image(data, hdr)
327329
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')
328330

@@ -344,7 +346,7 @@ def test_pscalar():
344346
matrix.append(scalar_map)
345347
matrix.append(parcel_map)
346348
hdr = ci.Cifti2Header(matrix)
347-
data = np.random.randn(2, 3)
349+
data = np.random.randn(2, 4)
348350
img = ci.Cifti2Image(data, hdr)
349351
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')
350352

@@ -366,7 +368,7 @@ def test_pdconn():
366368
matrix.append(geometry_map)
367369
matrix.append(parcel_map)
368370
hdr = ci.Cifti2Header(matrix)
369-
data = np.random.randn(2, 3)
371+
data = np.random.randn(10, 4)
370372
img = ci.Cifti2Image(data, hdr)
371373
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')
372374

@@ -388,7 +390,7 @@ def test_dpconn():
388390
matrix.append(parcel_map)
389391
matrix.append(geometry_map)
390392
hdr = ci.Cifti2Header(matrix)
391-
data = np.random.randn(2, 3)
393+
data = np.random.randn(4, 10)
392394
img = ci.Cifti2Image(data, hdr)
393395
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')
394396

@@ -410,7 +412,7 @@ def test_plabel():
410412
matrix.append(label_map)
411413
matrix.append(parcel_map)
412414
hdr = ci.Cifti2Header(matrix)
413-
data = np.random.randn(2, 3)
415+
data = np.random.randn(2, 4)
414416
img = ci.Cifti2Image(data, hdr)
415417

416418
with InTemporaryDirectory():
@@ -429,7 +431,7 @@ def test_pconn():
429431
matrix = ci.Cifti2Matrix()
430432
matrix.append(mapping)
431433
hdr = ci.Cifti2Header(matrix)
432-
data = np.random.randn(3, 3)
434+
data = np.random.randn(4, 4)
433435
img = ci.Cifti2Image(data, hdr)
434436
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')
435437

@@ -453,7 +455,7 @@ def test_pconnseries():
453455
matrix.append(parcel_map)
454456
matrix.append(series_map)
455457
hdr = ci.Cifti2Header(matrix)
456-
data = np.random.randn(3, 3, 13)
458+
data = np.random.randn(4, 4, 13)
457459
img = ci.Cifti2Image(data, hdr)
458460
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
459461
'PARCELLATED_SERIES')
@@ -479,7 +481,7 @@ def test_pconnscalar():
479481
matrix.append(parcel_map)
480482
matrix.append(scalar_map)
481483
hdr = ci.Cifti2Header(matrix)
482-
data = np.random.randn(3, 3, 13)
484+
data = np.random.randn(4, 4, 2)
483485
img = ci.Cifti2Image(data, hdr)
484486
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
485487
'PARCELLATED_SCALAR')
@@ -496,3 +498,29 @@ def test_pconnscalar():
496498
check_parcel_map(img2.header.matrix.get_index_map(0))
497499
check_scalar_map(img2.header.matrix.get_index_map(2))
498500
del img2
501+
502+
503+
def test_wrong_shape():
504+
scalar_map = create_scalar_map((0, ))
505+
brain_model_map = create_geometry_map((1, ))
506+
507+
matrix = ci.Cifti2Matrix()
508+
matrix.append(scalar_map)
509+
matrix.append(brain_model_map)
510+
hdr = ci.Cifti2Header(matrix)
511+
512+
# correct shape is (2, 10)
513+
for data in (
514+
np.random.randn(1, 11),
515+
np.random.randn(2, 10, 1),
516+
np.random.randn(1, 2, 10),
517+
np.random.randn(3, 10),
518+
np.random.randn(2, 9),
519+
):
520+
with clear_and_catch_warnings():
521+
with error_warnings():
522+
assert_raises(UserWarning, ci.Cifti2Image, data, hdr)
523+
with suppress_warnings():
524+
img = ci.Cifti2Image(data, hdr)
525+
assert_raises(ValueError, img.to_file_map)
526+

0 commit comments

Comments
 (0)