From 73c1295982c475d3168e6285294c66e0eab7410f Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 10 Dec 2023 20:02:48 -0800 Subject: [PATCH 01/32] add sure loss, its test functions and its documents Signed-off-by: chaoliu --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/sure_loss.py | 211 ++++++++++++++++++++++++++++++++++++++ tests/test_sure_loss.py | 66 ++++++++++++ 4 files changed, 283 insertions(+) create mode 100644 monai/losses/sure_loss.py create mode 100644 tests/test_sure_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index e929e9d605..91fa495100 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -134,6 +134,11 @@ Reconstruction Losses .. autoclass:: JukeboxLoss :members: +`SURELoss` +~~~~~~~~~~~~~~ +.. autoclass:: SURELoss + :members: + Loss Wrappers ------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 92898c81ca..3345868d4b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -42,3 +42,4 @@ from .ssim_loss import SSIMLoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss +from .sure_loss import SURELoss diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py new file mode 100644 index 0000000000..6f6bb371ee --- /dev/null +++ b/monai/losses/sure_loss.py @@ -0,0 +1,211 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + First compute the difference in the complex domain, + then get the absolute value and take the mse + + Args: + x, y - B, 2, H, W real valued tensors representing complex numbers + or B,1,H,W complex valued tensors + Returns: + l2_loss - scalar + """ + if not x.is_complex(): + x_ = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) + else: + x_ = x + if not y.is_complex(): + y_ = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) + else: + y_ = y + diff = x_ - y_ + l2_loss = nn.functional.mse_loss(torch.abs(diff), torch.zeros_like(torch.abs(diff)), reduction="mean") + return l2_loss + + +def sure_loss_function( + operator: callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: torch.Tensor = None, + eps: float = None, + perturb_noise: torch.Tensor = None, + complex_input: bool = False, +) -> torch.Tensor: + """ + + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape (B, + C, H, W) used to compute the L2 loss. For complex input, the shape is + (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) + real. + + y_ref (torch.Tensor, optional): The reference output tensor of shape + (B, C, H, W) used to compute the divergence. Defaults to None. For + complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, + the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Defaults to None. + + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, + C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, + W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + + complex_input(bool, optional): Whether the input is complex or not. + Defaults to False. + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + # perturb input + if perturb_noise is None: + perturb_noise = torch.randn_like(x) + if eps is None: + eps = torch.abs(y_pseudo_gt.max()) / 1000 + # get y_ref if not provided + if y_ref is None: + y_ref = operator(x) + + # get perturbed output + x_perturbed = x + eps * perturb_noise + y_perturbed = operator(x_perturbed) + # divergence + divergence = torch.sum(1 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) + # l2 loss between y_ref, y_pseudo_gt + if complex_input: + l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) + else: + # real input + l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean") + + # sure loss + sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3]) + return sure_loss + + +class SURELoss(_Loss): + """ + Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. + + This is a differentiable loss function that can be used to train/giude an + operator (e.g. neural network), where the pseudo ground truth is available + but the reference ground truth is not. For example, in the MRI + reconstruction, the pseudo ground truth is the zero-filled reconstruction + and the reference ground truth is the fully sampled reconstruction. Often, + the reference ground truth is not available due to the lack of fully sampled + data. + + The original SURE loss is proposed in [1]. The SURE loss used for guiding + the diffusion model based MRI reconstruction is proposed in [2]. + + Reference + [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics + [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. + (https://arxiv.org/pdf/2310.01799.pdf) + """ + + def __init__(self, perturb_noise: torch.Tensor = None, eps: float = None) -> None: + """ + Args: + + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, + C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, + W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Defaults to None. + + """ + + super().__init__() + self.perturb_noise = perturb_noise + self.eps = eps + + def forward( + self, + operator: callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: torch.Tensor = None, + complex_input: bool = False, + ) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input tensor + x and returns an output tensor y. We will use this to compute the + divergence. More specifically, we will perturb the input x by a small + amount and compute the divergence between the perturbed output and the + reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the operator. C=1 or 2: + For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. C=1 or 2: + For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_ref (torch.Tensor, optional): The reference output tensor of the same shape as y_pseudo_gt + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + # TODO: support for C>2 real valued input + + # check inputs + # dim check: + assert x.dim() == 4, "Input tensor x should be 4D." + assert y_pseudo_gt.dim() == 4, "Input tensor y_pseudo_gt should be 4D." + if y_ref is not None: + assert y_ref.dim() == 4, "Input tensor y_ref should be 4D." + + # complex/real check: + if complex_input: + assert ( + x.shape[1] == 2 and y_pseudo_gt.shape[1] == 2 and not x.is_complex() and not y_pseudo_gt.is_complex() + ), "For complex input, the shape is (B, 2, H, W) aka C=2 real or (B, 1, H, W) aka C=1 complex" + if y_ref is not None: + assert ( + y_ref.shape[1] == 2 and not y_ref.is_complex() + ), "For complex input, the shape is (B, 2, H, W) aka C=2 real or (B, 1, H, W) aka C=1 complex" + else: # real input + assert ( + x.shape[1] == 1 and y_pseudo_gt.shape[1] == 1 and not x.is_complex() and not y_pseudo_gt.is_complex() + ), "For real input, the shape is (B, 1, H, W) real." + if y_ref is not None: + assert y_ref.shape[1] == 1 and not y_ref.is_complex(), "For real input, the shape is (B, 1, H, W) real." + + # shape check + assert x.shape == y_pseudo_gt.shape, "Input tensor x and y_pseudo_gt should have the same shape." + if y_ref is not None: + assert y_pseudo_gt.shape == y_ref.shape, "Input tensor y_pseudo_gt and y_ref should have the same shape." + + # compute loss + loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) + + return loss \ No newline at end of file diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py new file mode 100644 index 0000000000..7162c7d785 --- /dev/null +++ b/tests/test_sure_loss.py @@ -0,0 +1,66 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +import torch +from monai.losses import SURELoss + + +class TestSURELoss(unittest.TestCase): + def test_real_value(self): + """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" + sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + def operator(x): + return x + y_pseudo_gt = torch.randn(2, 1, 128, 128) + x = torch.randn(2, 1, 128, 128) + loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) + self.assertAlmostEquals(loss.item(), 0.0) + print('real value test passed') + + def test_complex_value(self): + """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + def operator(x): + return x + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2,2,128,128), eps=0.1) + y_pseudo_gt = torch.randn(2, 2, 128, 128) + x = torch.randn(2, 2, 128, 128) + loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) + self.assertAlmostEquals(loss.item(), 0.0) + print('complex value test passed') + + def test_complex_general_input(self): + """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + def operator(x): + return x + perturb_noise_real = torch.randn(2,1,128,128) + perturb_noise_complex = torch.zeros(2,2,128,128) + perturb_noise_complex[:,0,:,:] = perturb_noise_real.squeeze() + y_pseudo_gt_real = torch.randn(2, 1, 128, 128) + y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) + y_pseudo_gt_complex[:,0,:,:] = y_pseudo_gt_real.squeeze() + x_real = torch.randn(2, 1, 128, 128) + x_complex = torch.zeros(2, 2, 128, 128) + x_complex[:,0,:,:] = x_real.squeeze() + + sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) + sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) + + loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) + loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) + self.assertAlmostEquals(loss_real.item(), loss_complex.abs().item()) + print('complex general input test passed') + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 74205dc0de9340c09215afdf38503e427d2b8890 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 10 Dec 2023 20:34:37 -0800 Subject: [PATCH 02/32] modified docs Signed-off-by: chaoliu --- monai/losses/__init__.py | 2 +- monai/losses/sure_loss.py | 29 ++++++++--------------------- tests/test_sure_loss.py | 30 +++++++++++++++++++----------- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 3345868d4b..489454707b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -40,6 +40,6 @@ from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss +from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss -from .sure_loss import SURELoss diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 6f6bb371ee..c0754bf743 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -126,7 +126,9 @@ class SURELoss(_Loss): the diffusion model based MRI reconstruction is proposed in [2]. Reference + [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics + [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. (https://arxiv.org/pdf/2310.01799.pdf) """ @@ -134,13 +136,8 @@ class SURELoss(_Loss): def __init__(self, perturb_noise: torch.Tensor = None, eps: float = None) -> None: """ Args: - - perturb_noise (torch.Tensor, optional): The noise vector of shape (B, - C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, - W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - - eps (float, optional): The perturbation scalar. Defaults to None. - + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + eps (float, optional): The perturbation scalar. Defaults to None. """ super().__init__() @@ -157,19 +154,9 @@ def forward( ) -> torch.Tensor: """ Args: - operator (function): The operator function that takes in an input tensor - x and returns an output tensor y. We will use this to compute the - divergence. More specifically, we will perturb the input x by a small - amount and compute the divergence between the perturbed output and the - reference output - - x (torch.Tensor): The input tensor of shape (B, C, H, W) to the operator. C=1 or 2: - For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - - y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape - (B, C, H, W) used to compute the L2 loss. C=1 or 2: - For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - + operator (function): The operator function that takes in an input tensor x and returns an output tensor y. We will use this to compute the divergence. More specifically, we will perturb the input x by a small amount and compute the divergence between the perturbed output and the reference output + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. y_ref (torch.Tensor, optional): The reference output tensor of the same shape as y_pseudo_gt Returns: @@ -208,4 +195,4 @@ def forward( # compute loss loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) - return loss \ No newline at end of file + return loss diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 7162c7d785..a8aea56dc7 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -12,7 +12,9 @@ from __future__ import annotations import unittest + import torch + from monai.losses import SURELoss @@ -20,38 +22,44 @@ class TestSURELoss(unittest.TestCase): def test_real_value(self): """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + def operator(x): return x + y_pseudo_gt = torch.randn(2, 1, 128, 128) x = torch.randn(2, 1, 128, 128) loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) self.assertAlmostEquals(loss.item(), 0.0) - print('real value test passed') - + print("real value test passed") + def test_complex_value(self): """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + def operator(x): return x - sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2,2,128,128), eps=0.1) + + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) y_pseudo_gt = torch.randn(2, 2, 128, 128) x = torch.randn(2, 2, 128, 128) loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) self.assertAlmostEquals(loss.item(), 0.0) - print('complex value test passed') + print("complex value test passed") def test_complex_general_input(self): """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + def operator(x): return x - perturb_noise_real = torch.randn(2,1,128,128) - perturb_noise_complex = torch.zeros(2,2,128,128) - perturb_noise_complex[:,0,:,:] = perturb_noise_real.squeeze() + + perturb_noise_real = torch.randn(2, 1, 128, 128) + perturb_noise_complex = torch.zeros(2, 2, 128, 128) + perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() y_pseudo_gt_real = torch.randn(2, 1, 128, 128) y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) - y_pseudo_gt_complex[:,0,:,:] = y_pseudo_gt_real.squeeze() + y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() x_real = torch.randn(2, 1, 128, 128) x_complex = torch.zeros(2, 2, 128, 128) - x_complex[:,0,:,:] = x_real.squeeze() + x_complex[:, 0, :, :] = x_real.squeeze() sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) @@ -59,8 +67,8 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) self.assertAlmostEquals(loss_real.item(), loss_complex.abs().item()) - print('complex general input test passed') + print("complex general input test passed") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From ad597e554b496d0cc932c89dde273f1e983b171c Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 10 Dec 2023 21:59:32 -0800 Subject: [PATCH 03/32] add conjugate gradient: class, unit test and doc Signed-off-by: chaoliu --- docs/source/networks.rst | 5 + monai/networks/layers/__init__.py | 1 + monai/networks/layers/conjugate_gradient.py | 114 ++++++++++++++++++++ tests/test_conjugate_gradient.py | 55 ++++++++++ 4 files changed, 175 insertions(+) create mode 100644 monai/networks/layers/conjugate_gradient.py create mode 100644 tests/test_conjugate_gradient.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..7a546e0302 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -407,6 +407,11 @@ Layers ~~~~~~ .. autoclass:: LLTM :members: + +`ConjugateGradient` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConjugateGradient + :members: `Utilities` ~~~~~~~~~~~ diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..3a6e4aa554 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .conjugate_gradient import ConjugateGradient from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py new file mode 100644 index 0000000000..8ad83305d0 --- /dev/null +++ b/monai/networks/layers/conjugate_gradient.py @@ -0,0 +1,114 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn + + +class ConjugateGradient(nn.Module): + """ + Congugate Gradient (CG) solver for linear systems Ax = y. + + For A (linear_op) that is positive definite and self-adjoint, CG is + guaranteed to converge CG is often used to solve linear systems of the form + Ax = y, where A is too large to store explicitly, but can be computed via a + linear operator. + + As a result, here we won't set A explicitly as a matrix, but rather as a + linear operator. For example, A could be a FFT/IFFT operation + """ + + def __init__(self, linear_op: callable, num_iter: int, dbprint: bool = False): + """ + Args: + linear_op: Linear operator + num_iter: Number of iterations to run CG + dbprint [False]: Print residual at each iteration + """ + super(ConjugateGradient, self).__init__() + + self.A = linear_op + self.num_iter = num_iter + self.dbprint = dbprint + + def _zdot(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensors x1 and x2. + """ + if torch.is_complex(x1): + assert torch.is_complex(x2), "x1 and x2 must both be complex" + return torch.sum(x1.conj() * x2) + else: + return torch.sum(x1 * x2) + + def _zdot_single(self, x: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensor x and itself + """ + res = self._zdot(x, x) + if torch.is_complex(res): + return res.real + else: + return res + + def _update(self, iter: int) -> callable: + """ + perform one iteration of the CG method. It takes the current solution x, + the current search direction p, the current residual r, and the old + residual norm rsold as inputs. Then it computes the new solution, search + direction, residual, and residual norm, and returns them. + """ + + def update_fn( + x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + dy = self.A(p) + p_dot_dy = self._zdot(p, dy) + alpha = rsold / p_dot_dy + x = x + alpha * p + r = r - alpha * dy + rsnew = self._zdot_single(r) + beta = rsnew / rsold + rsold = rsnew + p = beta * p + r + + # print residual + if self.dbprint: + print(f"CG Iteration {iter}: {rsnew}") + + return x, p, r, rsold + + return update_fn + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + run conjugate gradient for num_iter iterations to solve Ax = y + + Args: + x: B H W tensor (real or complex); Initial guess for linear system Ax = y + y: B H W tensor (real or complex); Measurement + + Returns: + x: Solution to Ax = y + """ + # Compute residual + r = y - self.A(x) + rsold = self._zdot_single(r) + p = r + + # Update + for i in range(self.num_iter): + x, p, r, rsold = self._update(i)(x, p, r, rsold) + if rsold < 1e-10: + break + return x diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py new file mode 100644 index 0000000000..f068dd6a4c --- /dev/null +++ b/tests/test_conjugate_gradient.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import ConjugateGradient + + +class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): + """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" + A_dim = 3 + A_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + + def A_op(x): + return A_mat @ x + + cg_solver = ConjugateGradient(A_op, num_iter=100, dbprint=False) + # define the measurement + y = torch.tensor([1, 2, 3], dtype=torch.float) + # solve for x + x = cg_solver(torch.zeros(A_dim), y) + x_ref = torch.linalg.solve(A_mat, y) + # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + print("real value test passed") + + def test_complex_valued_inverse(self): + A_dim = 3 + A_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + + def A_op(x): + return A_mat @ x + + cg_solver = ConjugateGradient(A_op, num_iter=100, dbprint=False) + y = torch.tensor([1, 2, 3], dtype=torch.complex64) + x = cg_solver(torch.zeros(A_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(A_mat, y) + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() From daa5889dfa43cd49258220eccfcdceeaaef9958b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Dec 2023 06:23:27 +0000 Subject: [PATCH 04/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/networks.rst | 2 +- monai/networks/layers/conjugate_gradient.py | 2 +- tests/test_sure_loss.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 7a546e0302..b59c8af5fc 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -407,7 +407,7 @@ Layers ~~~~~~ .. autoclass:: LLTM :members: - + `ConjugateGradient` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: ConjugateGradient diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index 8ad83305d0..327b6bd3b2 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -35,7 +35,7 @@ def __init__(self, linear_op: callable, num_iter: int, dbprint: bool = False): num_iter: Number of iterations to run CG dbprint [False]: Print residual at each iteration """ - super(ConjugateGradient, self).__init__() + super().__init__() self.A = linear_op self.num_iter = num_iter diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index a8aea56dc7..e4dc972fe4 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -29,7 +29,7 @@ def operator(x): y_pseudo_gt = torch.randn(2, 1, 128, 128) x = torch.randn(2, 1, 128, 128) loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) - self.assertAlmostEquals(loss.item(), 0.0) + self.assertAlmostEqual(loss.item(), 0.0) print("real value test passed") def test_complex_value(self): @@ -42,7 +42,7 @@ def operator(x): y_pseudo_gt = torch.randn(2, 2, 128, 128) x = torch.randn(2, 2, 128, 128) loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) - self.assertAlmostEquals(loss.item(), 0.0) + self.assertAlmostEqual(loss.item(), 0.0) print("complex value test passed") def test_complex_general_input(self): @@ -66,7 +66,7 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) - self.assertAlmostEquals(loss_real.item(), loss_complex.abs().item()) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item()) print("complex general input test passed") From 6cc6a7a79038a34c770db63e8c005acbf7797acb Mon Sep 17 00:00:00 2001 From: cxlcl Date: Mon, 11 Dec 2023 08:51:15 -0800 Subject: [PATCH 05/32] change the doc conjugate_gradient Signed-off-by: cxlcl --- monai/networks/layers/conjugate_gradient.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index 327b6bd3b2..bb1a68517e 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -43,7 +43,7 @@ def __init__(self, linear_op: callable, num_iter: int, dbprint: bool = False): def _zdot(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: """ - Complex dot product between tensors x1 and x2. + Complex dot product between tensors x1 and x2: sum(x1.*x2) """ if torch.is_complex(x1): assert torch.is_complex(x2), "x1 and x2 must both be complex" @@ -95,8 +95,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run conjugate gradient for num_iter iterations to solve Ax = y Args: - x: B H W tensor (real or complex); Initial guess for linear system Ax = y - y: B H W tensor (real or complex); Measurement + x: tensor (real or complex); Initial guess for linear system Ax = y. The size of x should be applicable to the linear operator. For example, if the linear operator is FFT, then x is HCHW; if the linear operator is a matrix multiplication, then x is a vector + y: tensor (real or complex); Measurement. Same size as x Returns: x: Solution to Ax = y From 254637741eec2907b9374deb11245d7f3bf0f586 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 18:02:55 -0800 Subject: [PATCH 06/32] fix CI error Signed-off-by: chaoliu --- monai/apps/detection/utils/anchor_utils.py | 8 ++++++-- monai/data/decathlon_datalist.py | 6 ++---- monai/losses/image_dissimilarity.py | 4 +--- monai/losses/sure_loss.py | 20 +++++++++++--------- monai/networks/layers/conjugate_gradient.py | 6 ++++-- monai/transforms/utility/dictionary.py | 6 +++--- monai/utils/dist.py | 9 +++------ monai/utils/misc.py | 6 ++---- tests/test_hilbert_transform.py | 20 +++++++++++--------- tests/test_spacing.py | 8 +++++--- tests/test_sure_loss.py | 2 +- 11 files changed, 49 insertions(+), 46 deletions(-) diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index baaa7ce874..283169b653 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -369,8 +369,12 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator): def __init__( self, feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8), - base_anchor_shapes: Sequence[Sequence[int]] - | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)), + base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = ( + (32, 32, 32), + (48, 20, 20), + (20, 48, 20), + (20, 20, 48), + ), indexing: str = "ij", ) -> None: nn.Module.__init__(self) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 39219e059a..dd132770ec 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -277,9 +277,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torc if order == 0: weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = ( - weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 - ) + weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: raise ValueError(f"Do not support b-spline {order}-order parzen windowing") diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index c0754bf743..866af1968b 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Callable, Optional + import torch import torch.nn as nn from torch.nn.modules.loss import _Loss @@ -41,13 +43,13 @@ def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def sure_loss_function( - operator: callable, + operator: Callable, x: torch.Tensor, y_pseudo_gt: torch.Tensor, - y_ref: torch.Tensor = None, - eps: float = None, - perturb_noise: torch.Tensor = None, - complex_input: bool = False, + y_ref: Optional[torch.Tensor] = None, + eps: Optional[float] = None, + perturb_noise: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, ) -> torch.Tensor: """ @@ -133,7 +135,7 @@ class SURELoss(_Loss): (https://arxiv.org/pdf/2310.01799.pdf) """ - def __init__(self, perturb_noise: torch.Tensor = None, eps: float = None) -> None: + def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: """ Args: perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. @@ -146,11 +148,11 @@ def __init__(self, perturb_noise: torch.Tensor = None, eps: float = None) -> Non def forward( self, - operator: callable, + operator: Callable, x: torch.Tensor, y_pseudo_gt: torch.Tensor, - y_ref: torch.Tensor = None, - complex_input: bool = False, + y_ref: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, ) -> torch.Tensor: """ Args: diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index bb1a68517e..de578e95e5 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Callable, Optional + import torch from torch import nn @@ -28,7 +30,7 @@ class ConjugateGradient(nn.Module): linear operator. For example, A could be a FFT/IFFT operation """ - def __init__(self, linear_op: callable, num_iter: int, dbprint: bool = False): + def __init__(self, linear_op: Callable, num_iter: int, dbprint: Optional[bool] = False): """ Args: linear_op: Linear operator @@ -61,7 +63,7 @@ def _zdot_single(self, x: torch.Tensor) -> torch.Tensor: else: return res - def _update(self, iter: int) -> callable: + def _update(self, iter: int) -> Callable: """ perform one iteration of the CG method. It takes the current solution x, the current search direction p, the current residual r, and the old diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index ec10bd8537..1cd9ff6323 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1765,9 +1765,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd -ConvertToMultiChannelBasedOnBratsClassesD = ( - ConvertToMultiChannelBasedOnBratsClassesDict -) = ConvertToMultiChannelBasedOnBratsClassesd +ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = ( + ConvertToMultiChannelBasedOnBratsClassesd +) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..2a5c5da136 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 4c49aecd8b..68fa0b1192 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -180,15 +180,17 @@ def test_value(self, arguments, image, expected_data, atol): @SkipIfNoModule("torch.fft") class TestHilbertTransformGPU(unittest.TestCase): @parameterized.expand( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ], + ( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ] + ), skip_on_empty=True, ) def test_value(self, arguments, image, expected_data, atol): diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 1ff1518297..8b664641d7 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -74,9 +74,11 @@ torch.ones((1, 2, 1, 2)), # data torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), {}, - torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) - if USE_COMPILED - else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ( + torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) + if USE_COMPILED + else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) + ), *device, ] ) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index e4dc972fe4..388acea35c 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -66,7 +66,7 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) - self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item()) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) print("complex general input test passed") From 2b5b895b19f029fa9dbdcd77060701da3d621be2 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 19:12:05 -0800 Subject: [PATCH 07/32] fix CI error Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 40 +++++++++++++++------ monai/networks/layers/conjugate_gradient.py | 6 +++- tests/test_conjugate_gradient.py | 31 ++++++++-------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 866af1968b..6938ac9cf2 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch import torch.nn as nn @@ -47,7 +47,7 @@ def sure_loss_function( x: torch.Tensor, y_pseudo_gt: torch.Tensor, y_ref: Optional[torch.Tensor] = None, - eps: Optional[float] = None, + eps: Optional[float] = -1.0, perturb_noise: Optional[torch.Tensor] = None, complex_input: Optional[bool] = False, ) -> torch.Tensor: @@ -74,7 +74,8 @@ def sure_loss_function( complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - eps (float, optional): The perturbation scalar. Defaults to None. + eps (float, optional): The perturbation scalar. Set to -1 to set it + automatically estimated based on y_pseudo_gtk perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, @@ -89,8 +90,8 @@ def sure_loss_function( # perturb input if perturb_noise is None: perturb_noise = torch.randn_like(x) - if eps is None: - eps = torch.abs(y_pseudo_gt.max()) / 1000 + if eps == -1.0: + eps = float(torch.abs(y_pseudo_gt.max())) / 1000 # get y_ref if not provided if y_ref is None: y_ref = operator(x) @@ -99,7 +100,7 @@ def sure_loss_function( x_perturbed = x + eps * perturb_noise y_perturbed = operator(x_perturbed) # divergence - divergence = torch.sum(1 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # l2 loss between y_ref, y_pseudo_gt if complex_input: l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) @@ -138,7 +139,11 @@ class SURELoss(_Loss): def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: """ Args: - perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + perturb_noise (torch.Tensor, optional): The noise vector of shape + (B, C, H, W). Defaults to None. For complex input, the shape is (B, + 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) + real. + eps (float, optional): The perturbation scalar. Defaults to None. """ @@ -156,10 +161,23 @@ def forward( ) -> torch.Tensor: """ Args: - operator (function): The operator function that takes in an input tensor x and returns an output tensor y. We will use this to compute the divergence. More specifically, we will perturb the input x by a small amount and compute the divergence between the perturbed output and the reference output - x (torch.Tensor): The input tensor of shape (B, C, H, W) to the operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - y_ref (torch.Tensor, optional): The reference output tensor of the same shape as y_pseudo_gt + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka + C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex + input, the shape is (B, 2, H, W) aka C=2 real. For real input, the + shape is (B, 1, H, W) real. + + y_ref (torch.Tensor, optional): The reference output tensor of the + same shape as y_pseudo_gt Returns: sure_loss (torch.Tensor): The SURE loss scalar. diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index de578e95e5..2dd2b01cbb 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -97,7 +97,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run conjugate gradient for num_iter iterations to solve Ax = y Args: - x: tensor (real or complex); Initial guess for linear system Ax = y. The size of x should be applicable to the linear operator. For example, if the linear operator is FFT, then x is HCHW; if the linear operator is a matrix multiplication, then x is a vector + x: tensor (real or complex); Initial guess for linear system Ax = y. + The size of x should be applicable to the linear operator. For + example, if the linear operator is FFT, then x is HCHW; if the + linear operator is a matrix multiplication, then x is a vector + y: tensor (real or complex); Measurement. Same size as x Returns: diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index f068dd6a4c..c0b67ea213 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -20,34 +20,35 @@ class TestConjugateGradient(unittest.TestCase): def test_real_valued_inverse(self): - """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" - A_dim = 3 - A_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + """Test ConjugateGradient with real-valued input: when the input is real + value, the output should be the inverse of the matrix.""" + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) - def A_op(x): - return A_mat @ x + def a_op(x): + return a_mat @ x - cg_solver = ConjugateGradient(A_op, num_iter=100, dbprint=False) + cg_solver = ConjugateGradient(a_op, num_iter=100, dbprint=False) # define the measurement y = torch.tensor([1, 2, 3], dtype=torch.float) # solve for x - x = cg_solver(torch.zeros(A_dim), y) - x_ref = torch.linalg.solve(A_mat, y) + x = cg_solver(torch.zeros(a_dim), y) + x_ref = torch.linalg.solve(a_mat, y) # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) print("real value test passed") def test_complex_valued_inverse(self): - A_dim = 3 - A_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) - def A_op(x): - return A_mat @ x + def a_op(x): + return a_mat @ x - cg_solver = ConjugateGradient(A_op, num_iter=100, dbprint=False) + cg_solver = ConjugateGradient(a_op, num_iter=100, dbprint=False) y = torch.tensor([1, 2, 3], dtype=torch.complex64) - x = cg_solver(torch.zeros(A_dim, dtype=torch.complex64), y) - x_ref = torch.linalg.solve(A_mat, y) + x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(a_mat, y) self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) From 29dbaa0b15ae60f06156b6c2014565f3c08b024b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 03:13:24 +0000 Subject: [PATCH 08/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 6938ac9cf2..244792a21c 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch import torch.nn as nn From 17f5ea0be35e5630de8bf00a484057ee730b7210 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 19:36:57 -0800 Subject: [PATCH 09/32] fix CI error Signed-off-by: chaoliu --- monai/data/decathlon_datalist.py | 6 ++++-- monai/losses/sure_loss.py | 17 +++++++---------- monai/utils/dist.py | 9 ++++++--- monai/utils/misc.py | 6 ++++-- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 14765dcfaa..6f163f972e 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,11 +24,13 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: + ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: + ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 6938ac9cf2..fbb23d8c6d 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -64,8 +64,8 @@ def sure_loss_function( operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. - y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape (B, - C, H, W) used to compute the L2 loss. For complex input, the shape is + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. For complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. @@ -77,9 +77,9 @@ def sure_loss_function( eps (float, optional): The perturbation scalar. Set to -1 to set it automatically estimated based on y_pseudo_gtk - perturb_noise (torch.Tensor, optional): The noise vector of shape (B, - C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, - W) aka C=2 real. For real input, the shape is (B, 1, H, W) real. + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). + Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. complex_input(bool, optional): Whether the input is complex or not. Defaults to False. @@ -140,13 +140,11 @@ def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[f """ Args: perturb_noise (torch.Tensor, optional): The noise vector of shape - (B, C, H, W). Defaults to None. For complex input, the shape is (B, - 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) - real. + (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. eps (float, optional): The perturbation scalar. Defaults to None. """ - super().__init__() self.perturb_noise = perturb_noise self.eps = eps @@ -182,7 +180,6 @@ def forward( Returns: sure_loss (torch.Tensor): The SURE loss scalar. """ - # TODO: support for C>2 real valued input # check inputs # dim check: diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 2418b43591..20f09628ac 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,15 +50,18 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: + ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: + ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: + ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 2a5c5da136..f4afa09178 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,11 +103,13 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: ... +def first(iterable: Iterable[T], default: T) -> T: + ... @overload -def first(iterable: Iterable[T]) -> T | None: ... +def first(iterable: Iterable[T]) -> T | None: + ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: From 1f75eb1e849d88c6d360c646125dc7761ef20878 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 19:39:28 -0800 Subject: [PATCH 10/32] fix CI error Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 97de7ee58c..39d4b06e78 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -52,7 +52,6 @@ def sure_loss_function( complex_input: Optional[bool] = False, ) -> torch.Tensor: """ - Args: operator (function): The operator function that takes in an input tensor x and returns an output tensor y. We will use this to compute From e8b34e77df61596f2979d1b70dee77bc3b424c62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 03:40:49 +0000 Subject: [PATCH 11/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index f4afa09178..d6ff370f69 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -108,7 +108,7 @@ def first(iterable: Iterable[T], default: T) -> T: @overload -def first(iterable: Iterable[T]) -> T | None: +def first(iterable: Iterable[T]) -> T | None: ... From 7b0e98284bee292232a9374b4ddf0f76b2864c81 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 20:01:56 -0800 Subject: [PATCH 12/32] fix CI error Signed-off-by: chaoliu --- monai/data/decathlon_datalist.py | 6 ++---- monai/utils/dist.py | 9 +++------ monai/utils/misc.py | 6 ++---- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index f4afa09178..2a5c5da136 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: From ea8147495e7473c117c01421131ceafbc31f023b Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sun, 28 Jan 2024 20:56:45 -0800 Subject: [PATCH 13/32] fix CI: after running ./runtest --autofix Signed-off-by: chaoliu --- monai/data/decathlon_datalist.py | 6 ++---- monai/utils/dist.py | 9 +++------ monai/utils/misc.py | 6 ++---- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..2a5c5da136 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: From 58ee71297524a2ca5dc5ea2579fac079da593119 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Tue, 30 Jan 2024 10:54:04 -0800 Subject: [PATCH 14/32] fix CI: after running ./runtest --autofix Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 39d4b06e78..6da7c743c2 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -99,7 +99,7 @@ def sure_loss_function( x_perturbed = x + eps * perturb_noise y_perturbed = operator(x_perturbed) # divergence - divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore # l2 loss between y_ref, y_pseudo_gt if complex_input: l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) From 9755aa580ab152f3fa1d6c59c6123c061dfe7670 Mon Sep 17 00:00:00 2001 From: chaoliu Date: Tue, 30 Jan 2024 12:58:23 -0800 Subject: [PATCH 15/32] fix CI: after running ./runtest --autofix Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 6da7c743c2..3b813850a0 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -99,7 +99,7 @@ def sure_loss_function( x_perturbed = x + eps * perturb_noise y_perturbed = operator(x_perturbed) # divergence - divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore # l2 loss between y_ref, y_pseudo_gt if complex_input: l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) From e2aad7403a4f98fd03812134114f90afe15071e6 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Sat, 24 Feb 2024 09:23:58 -0800 Subject: [PATCH 16/32] Update monai/losses/sure_loss.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: cxlcl --- monai/losses/sure_loss.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 3b813850a0..8372002457 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -30,16 +30,12 @@ def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: l2_loss - scalar """ if not x.is_complex(): - x_ = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) - else: - x_ = x + x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) if not y.is_complex(): - y_ = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) - else: - y_ = y - diff = x_ - y_ - l2_loss = nn.functional.mse_loss(torch.abs(diff), torch.zeros_like(torch.abs(diff)), reduction="mean") - return l2_loss + y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) + + diff = torch.abs(x - y) + return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") def sure_loss_function( From f9f8e6ff28f12aa6b922f10fabfe45c47185ad59 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Feb 2024 17:24:21 +0000 Subject: [PATCH 17/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 8372002457..56a4a3b140 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -33,7 +33,7 @@ def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) if not y.is_complex(): y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) - + diff = torch.abs(x - y) return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") From b995cb42b7f482e8a092bdf160bfc0ef2a4325ee Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sat, 24 Feb 2024 10:04:54 -0800 Subject: [PATCH 18/32] modifications based on revision Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 38 +++------ monai/networks/layers/conjugate_gradient.py | 88 +++++++++------------ tests/test_conjugate_gradient.py | 5 +- 3 files changed, 53 insertions(+), 78 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 56a4a3b140..fc828b77c8 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -176,33 +176,17 @@ def forward( sure_loss (torch.Tensor): The SURE loss scalar. """ - # check inputs - # dim check: - assert x.dim() == 4, "Input tensor x should be 4D." - assert y_pseudo_gt.dim() == 4, "Input tensor y_pseudo_gt should be 4D." - if y_ref is not None: - assert y_ref.dim() == 4, "Input tensor y_ref should be 4D." - - # complex/real check: - if complex_input: - assert ( - x.shape[1] == 2 and y_pseudo_gt.shape[1] == 2 and not x.is_complex() and not y_pseudo_gt.is_complex() - ), "For complex input, the shape is (B, 2, H, W) aka C=2 real or (B, 1, H, W) aka C=1 complex" - if y_ref is not None: - assert ( - y_ref.shape[1] == 2 and not y_ref.is_complex() - ), "For complex input, the shape is (B, 2, H, W) aka C=2 real or (B, 1, H, W) aka C=1 complex" - else: # real input - assert ( - x.shape[1] == 1 and y_pseudo_gt.shape[1] == 1 and not x.is_complex() and not y_pseudo_gt.is_complex() - ), "For real input, the shape is (B, 1, H, W) real." - if y_ref is not None: - assert y_ref.shape[1] == 1 and not y_ref.is_complex(), "For real input, the shape is (B, 1, H, W) real." - - # shape check - assert x.shape == y_pseudo_gt.shape, "Input tensor x and y_pseudo_gt should have the same shape." - if y_ref is not None: - assert y_pseudo_gt.shape == y_ref.shape, "Input tensor y_pseudo_gt and y_ref should have the same shape." + # check inputs shapes + if x.dim() != 4: + raise ValueError("Input tensor x should be 4D.") + if y_pseudo_gt.dim() != 4: + raise ValueError("Input tensor y_pseudo_gt should be 4D.") + if y_ref is not None and y_ref.dim() != 4: + raise ValueError("Input tensor y_ref should be 4D.") + if x.shape != y_pseudo_gt.shape: + raise ValueError("Input tensor x and y_pseudo_gt should have the same shape.") + if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: + raise ValueError("Input tensor y_pseudo_gt and y_ref should have the same shape.") # compute loss loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index 2dd2b01cbb..a47d48b504 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -16,12 +16,32 @@ import torch from torch import nn +def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensors x1 and x2: sum(x1.*x2) + """ + if torch.is_complex(x1): + assert torch.is_complex(x2), "x1 and x2 must both be complex" + return torch.sum(x1.conj() * x2) + else: + return torch.sum(x1 * x2) + +def _zdot_single(x: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensor x and itself + """ + res = _zdot(x, x) + if torch.is_complex(res): + return res.real + else: + return res + class ConjugateGradient(nn.Module): """ Congugate Gradient (CG) solver for linear systems Ax = y. - For A (linear_op) that is positive definite and self-adjoint, CG is + For linear_op that is positive definite and self-adjoint, CG is guaranteed to converge CG is often used to solve linear systems of the form Ax = y, where A is too large to store explicitly, but can be computed via a linear operator. @@ -30,40 +50,20 @@ class ConjugateGradient(nn.Module): linear operator. For example, A could be a FFT/IFFT operation """ - def __init__(self, linear_op: Callable, num_iter: int, dbprint: Optional[bool] = False): + def __init__(self, linear_op: Callable, num_iter: int,): """ Args: linear_op: Linear operator num_iter: Number of iterations to run CG - dbprint [False]: Print residual at each iteration """ super().__init__() - self.A = linear_op + self.linear_op = linear_op self.num_iter = num_iter - self.dbprint = dbprint - - def _zdot(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Complex dot product between tensors x1 and x2: sum(x1.*x2) - """ - if torch.is_complex(x1): - assert torch.is_complex(x2), "x1 and x2 must both be complex" - return torch.sum(x1.conj() * x2) - else: - return torch.sum(x1 * x2) - def _zdot_single(self, x: torch.Tensor) -> torch.Tensor: - """ - Complex dot product between tensor x and itself - """ - res = self._zdot(x, x) - if torch.is_complex(res): - return res.real - else: - return res - - def _update(self, iter: int) -> Callable: + def update( + self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ perform one iteration of the CG method. It takes the current solution x, the current search direction p, the current residual r, and the old @@ -71,26 +71,16 @@ def _update(self, iter: int) -> Callable: direction, residual, and residual norm, and returns them. """ - def update_fn( - x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - dy = self.A(p) - p_dot_dy = self._zdot(p, dy) - alpha = rsold / p_dot_dy - x = x + alpha * p - r = r - alpha * dy - rsnew = self._zdot_single(r) - beta = rsnew / rsold - rsold = rsnew - p = beta * p + r - - # print residual - if self.dbprint: - print(f"CG Iteration {iter}: {rsnew}") - - return x, p, r, rsold - - return update_fn + dy = self.linear_op(p) + p_dot_dy = _zdot(p, dy) + alpha = rsold / p_dot_dy + x = x + alpha * p + r = r - alpha * dy + rsnew = _zdot_single(r) + beta = rsnew / rsold + rsold = rsnew + p = beta * p + r + return x, p, r, rsold def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ @@ -108,13 +98,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x: Solution to Ax = y """ # Compute residual - r = y - self.A(x) - rsold = self._zdot_single(r) + r = y - self.linear_op(x) + rsold = _zdot_single(r) p = r # Update for i in range(self.num_iter): - x, p, r, rsold = self._update(i)(x, p, r, rsold) + x, p, r, rsold = self.update(x, p, r, rsold) if rsold < 1e-10: break return x diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index c0b67ea213..0b428999b2 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -28,7 +28,7 @@ def test_real_valued_inverse(self): def a_op(x): return a_mat @ x - cg_solver = ConjugateGradient(a_op, num_iter=100, dbprint=False) + cg_solver = ConjugateGradient(a_op, num_iter=100,) # define the measurement y = torch.tensor([1, 2, 3], dtype=torch.float) # solve for x @@ -45,11 +45,12 @@ def test_complex_valued_inverse(self): def a_op(x): return a_mat @ x - cg_solver = ConjugateGradient(a_op, num_iter=100, dbprint=False) + cg_solver = ConjugateGradient(a_op, num_iter=100,) y = torch.tensor([1, 2, 3], dtype=torch.complex64) x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) x_ref = torch.linalg.solve(a_mat, y) self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + print("complex value test passed") if __name__ == "__main__": From 7a5c7124f3f9aa3a1e3958a512cb3e1c6865bc54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Feb 2024 18:06:09 +0000 Subject: [PATCH 19/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/sure_loss.py | 2 +- monai/networks/layers/conjugate_gradient.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index fc828b77c8..ff3a255c99 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -177,7 +177,7 @@ def forward( """ # check inputs shapes - if x.dim() != 4: + if x.dim() != 4: raise ValueError("Input tensor x should be 4D.") if y_pseudo_gt.dim() != 4: raise ValueError("Input tensor y_pseudo_gt should be 4D.") diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index a47d48b504..c46caa73d5 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable import torch from torch import nn From 007e779d950ef85e1d41ffc5b5b4dcf934e2319e Mon Sep 17 00:00:00 2001 From: chaoliu Date: Sat, 24 Feb 2024 10:06:13 -0800 Subject: [PATCH 20/32] Modifications based on revision Signed-off-by: chaoliu --- monai/losses/sure_loss.py | 2 +- monai/networks/layers/conjugate_gradient.py | 6 ++++-- tests/test_conjugate_gradient.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index fc828b77c8..ff3a255c99 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -177,7 +177,7 @@ def forward( """ # check inputs shapes - if x.dim() != 4: + if x.dim() != 4: raise ValueError("Input tensor x should be 4D.") if y_pseudo_gt.dim() != 4: raise ValueError("Input tensor y_pseudo_gt should be 4D.") diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index a47d48b504..b78b666019 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -16,6 +16,7 @@ import torch from torch import nn + def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: """ Complex dot product between tensors x1 and x2: sum(x1.*x2) @@ -26,6 +27,7 @@ def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: else: return torch.sum(x1 * x2) + def _zdot_single(x: torch.Tensor) -> torch.Tensor: """ Complex dot product between tensor x and itself @@ -50,7 +52,7 @@ class ConjugateGradient(nn.Module): linear operator. For example, A could be a FFT/IFFT operation """ - def __init__(self, linear_op: Callable, num_iter: int,): + def __init__(self, linear_op: Callable, num_iter: int): """ Args: linear_op: Linear operator @@ -63,7 +65,7 @@ def __init__(self, linear_op: Callable, num_iter: int,): def update( self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ perform one iteration of the CG method. It takes the current solution x, the current search direction p, the current residual r, and the old diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 0b428999b2..1391985de3 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -28,7 +28,7 @@ def test_real_valued_inverse(self): def a_op(x): return a_mat @ x - cg_solver = ConjugateGradient(a_op, num_iter=100,) + cg_solver = ConjugateGradient(a_op, num_iter=100) # define the measurement y = torch.tensor([1, 2, 3], dtype=torch.float) # solve for x @@ -45,7 +45,7 @@ def test_complex_valued_inverse(self): def a_op(x): return a_mat @ x - cg_solver = ConjugateGradient(a_op, num_iter=100,) + cg_solver = ConjugateGradient(a_op, num_iter=100) y = torch.tensor([1, 2, 3], dtype=torch.complex64) x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) x_ref = torch.linalg.solve(a_mat, y) From b7198fb24280b7c4595da4cff2133caa664f8d87 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:58:30 -0800 Subject: [PATCH 21/32] Update docs/source/losses.rst Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- docs/source/losses.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 91fa495100..9ecd1eea3a 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -135,7 +135,7 @@ Reconstruction Losses :members: `SURELoss` -~~~~~~~~~~~~~~ +~~~~~~~~~~ .. autoclass:: SURELoss :members: From 5590abc3bd4d94fcb0cb28b5974425f844a10cd8 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:58:47 -0800 Subject: [PATCH 22/32] Update monai/losses/sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index ff3a255c99..c270e76eb4 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -112,7 +112,7 @@ class SURELoss(_Loss): """ Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. - This is a differentiable loss function that can be used to train/giude an + This is a differentiable loss function that can be used to train/guide an operator (e.g. neural network), where the pseudo ground truth is available but the reference ground truth is not. For example, in the MRI reconstruction, the pseudo ground truth is the zero-filled reconstruction From 281b78083978ec66aee3c05ac5bf5807b6e2a01a Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:58:55 -0800 Subject: [PATCH 23/32] Update monai/losses/sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- monai/losses/sure_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index c270e76eb4..f8fc74b8d0 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -178,7 +178,7 @@ def forward( # check inputs shapes if x.dim() != 4: - raise ValueError("Input tensor x should be 4D.") + raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.") if y_pseudo_gt.dim() != 4: raise ValueError("Input tensor y_pseudo_gt should be 4D.") if y_ref is not None and y_ref.dim() != 4: From 36ca0d408e41a70a0e19432a2adf5f5bfd4c6a3d Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:02 -0800 Subject: [PATCH 24/32] Update monai/losses/sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- monai/losses/sure_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index f8fc74b8d0..01e0554eeb 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -180,13 +180,13 @@ def forward( if x.dim() != 4: raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.") if y_pseudo_gt.dim() != 4: - raise ValueError("Input tensor y_pseudo_gt should be 4D.") + raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.") if y_ref is not None and y_ref.dim() != 4: - raise ValueError("Input tensor y_ref should be 4D.") + raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") if x.shape != y_pseudo_gt.shape: - raise ValueError("Input tensor x and y_pseudo_gt should have the same shape.") + raise ValueError(f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, y_pseudo_gt shape {y_pseudo_gt.shape}.") if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: - raise ValueError("Input tensor y_pseudo_gt and y_ref should have the same shape.") + raise ValueError(f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, y_ref shape {y_ref.shape}.") # compute loss loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) From 7b8c39df888e89d0fab2752977ece32021bcf5d5 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:11 -0800 Subject: [PATCH 25/32] Update tests/test_conjugate_gradient.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- tests/test_conjugate_gradient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 1391985de3..a09416897e 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -36,7 +36,6 @@ def a_op(x): x_ref = torch.linalg.solve(a_mat, y) # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) - print("real value test passed") def test_complex_valued_inverse(self): a_dim = 3 From 20e81a751b46018f02601d64b7f54f701f3e5406 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:18 -0800 Subject: [PATCH 26/32] Update tests/test_sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- tests/test_sure_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 388acea35c..28ccd25516 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -43,7 +43,6 @@ def operator(x): x = torch.randn(2, 2, 128, 128) loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) self.assertAlmostEqual(loss.item(), 0.0) - print("complex value test passed") def test_complex_general_input(self): """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" From 008920b9b7ceec0bd16870db26c484ba1836b0b4 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:30 -0800 Subject: [PATCH 27/32] Update monai/networks/layers/conjugate_gradient.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- monai/networks/layers/conjugate_gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py index 0e6cf8531b..93a45930d7 100644 --- a/monai/networks/layers/conjugate_gradient.py +++ b/monai/networks/layers/conjugate_gradient.py @@ -105,7 +105,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: p = r # Update - for i in range(self.num_iter): + for _i in range(self.num_iter): x, p, r, rsold = self.update(x, p, r, rsold) if rsold < 1e-10: break From e5b9d13fab98150abea7e248aaa7dedb4b1f7448 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:43 -0800 Subject: [PATCH 28/32] Update tests/test_sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- tests/test_sure_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 28ccd25516..91815c4f2e 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -66,7 +66,6 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) - print("complex general input test passed") if __name__ == "__main__": From 31fb45666294976f27bdbf1f16e80707392cc02c Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 18:59:53 -0800 Subject: [PATCH 29/32] Update tests/test_sure_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- tests/test_sure_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 91815c4f2e..945da657bf 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -30,7 +30,6 @@ def operator(x): x = torch.randn(2, 1, 128, 128) loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) self.assertAlmostEqual(loss.item(), 0.0) - print("real value test passed") def test_complex_value(self): """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" From 3971581407fe415675c7571ae9f696d58c0fcb20 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Thu, 7 Mar 2024 19:00:04 -0800 Subject: [PATCH 30/32] Update tests/test_conjugate_gradient.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: cxlcl --- tests/test_conjugate_gradient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index a09416897e..239dbe3ecd 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -49,7 +49,6 @@ def a_op(x): x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) x_ref = torch.linalg.solve(a_mat, y) self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) - print("complex value test passed") if __name__ == "__main__": From 81b9434bb24d4f895fb7cf826a04165c40fa0e5e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 22 Mar 2024 21:17:10 +0800 Subject: [PATCH 31/32] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/losses/sure_loss.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index 01e0554eeb..a0a6c2c9f4 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -184,9 +184,13 @@ def forward( if y_ref is not None and y_ref.dim() != 4: raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") if x.shape != y_pseudo_gt.shape: - raise ValueError(f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, y_pseudo_gt shape {y_pseudo_gt.shape}.") + raise ValueError( + f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, y_pseudo_gt shape {y_pseudo_gt.shape}." + ) if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: - raise ValueError(f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, y_ref shape {y_ref.shape}.") + raise ValueError( + f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, y_ref shape {y_ref.shape}." + ) # compute loss loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) From 2c78f4ccb00d4ff73b66fcf592e1ebd5a845c50c Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 22 Mar 2024 21:20:06 +0800 Subject: [PATCH 32/32] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/losses/sure_loss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index a0a6c2c9f4..ebf25613a6 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -185,11 +185,13 @@ def forward( raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") if x.shape != y_pseudo_gt.shape: raise ValueError( - f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, y_pseudo_gt shape {y_pseudo_gt.shape}." + f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, " + f"y_pseudo_gt shape {y_pseudo_gt.shape}." ) if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: raise ValueError( - f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, y_ref shape {y_ref.shape}." + f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, " + f"y_ref shape {y_ref.shape}." ) # compute loss