Skip to content

Commit c649934

Browse files
Lucas-rbnt, ericspod, KumoLiu, pre-commit-ci[bot]
authored
Add Barlow Twins loss for representation learning (#7530)
### Description Addition of the BarlowTwinsLoss class. This cost function is introduced in the http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf paper with the aim of disentangling the representations learned on two views of the same sample, making it a powerful tool for multimodal and unsupervised learning. This cost function is similar to the InfoNCE Loss function already implemented in MONAI (https://docs.monai.io/en/latest/_modules/monai/losses/contrastive.html#ContrastiveLoss). However, it differs in several respects: there is no l2-normalisation, but rather a z-normalisation. In addition, rather than working between pairs of embeddings, Barlow Twins seeks to decorrelate the components of the representations. ```math \mathcal{L}_{BT} := \sum_i (1 - \mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{j\neq i} \mathcal{C}_{ij}^2 ``` with $\lambda$ a positive hyperparameter and $\mathcal{C}$ the cross-correlation matrix. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Lucas Robinet <[email protected]> Signed-off-by: Lucas Robinet <[email protected]> Co-authored-by: Lucas Robinet <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 35c93fd commit c649934

File tree

4 files changed

+199
-0
lines changed

4 files changed

+199
-0
lines changed

docs/source/losses.rst

+5
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ Segmentation Losses
7373
.. autoclass:: ContrastiveLoss
7474
:members:
7575

76+
`BarlowTwinsLoss`
77+
~~~~~~~~~~~~~~~~~
78+
.. autoclass:: BarlowTwinsLoss
79+
:members:
80+
7681
`HausdorffDTLoss`
7782
~~~~~~~~~~~~~~~~~
7883
.. autoclass:: HausdorffDTLoss

monai/losses/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from .adversarial_loss import PatchAdversarialLoss
15+
from .barlow_twins import BarlowTwinsLoss
1516
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1617
from .contrastive import ContrastiveLoss
1718
from .deform import BendingEnergyLoss, DiffusionLoss

monai/losses/barlow_twins.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import torch
15+
from torch.nn.modules.loss import _Loss
16+
17+
18+
class BarlowTwinsLoss(_Loss):
    """
    The Barlow Twins cost function takes the representations extracted by a neural network from two
    distorted views and seeks to make the cross-correlation matrix of the two representations tend
    towards identity. This encourages the neural network to learn similar representations with the least
    amount of redundancy. This cost function can be used in particular in multimodal learning to work on
    representations from two modalities. The most common use case is for unsupervised learning, where data
    augmentations are used to generate 2 distorted views of the same sample to force the encoder to
    extract useful features for downstream tasks.

    With ``C`` the FxF matrix computed in ``forward`` from the per-feature standardised ``input`` and
    ``target``, the returned value is::

        sum_i (C_ii - 1)^2 + lambd * sum_i sum_{j != i} C_ij^2

    Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
    conference on machine learning. PMLR, 2021. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)

    Adapted from:
        https://github.com/facebookresearch/barlowtwins

    """

    def __init__(self, lambd: float = 5e-3) -> None:
        """
        Args:
            lambd: weight of the off-diagonal term of the loss. Can be any float to handle the
                informativeness and invariance trade-off. Ideally set to 5e-3.

        """
        super().__init__()
        self.lambd = lambd

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BF (batch, features).
            target: the shape should be BF (batch, features).

        Returns:
            A 0-dimensional tensor holding the summed loss.

        Raises:
            ValueError: When an input of dimension length > 2 is passed
            ValueError: When an input of dimension length < 2 is passed
            ValueError: When input and target are of different shapes
            ValueError: When batch size is less than or equal to 1

        """
        if len(target.shape) > 2 or len(input.shape) > 2:
            raise ValueError(
                f"Either target or input has dimensions greater than 2 where target "
                f"shape is ({target.shape}) and input shape is ({input.shape})"
            )

        # Without this guard a 1-D tensor would only fail later inside torch.mm (RuntimeError) and a
        # 0-D tensor inside size(0) (IndexError), neither of which tells the caller what is wrong.
        if len(target.shape) < 2 or len(input.shape) < 2:
            raise ValueError(
                f"Both target and input must be 2-dimensional (batch, features), but target "
                f"shape is ({target.shape}) and input shape is ({input.shape})"
            )

        if target.shape != input.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # std(0) over a single sample is undefined, so at least two samples are required
        if target.size(0) <= 1:
            raise ValueError(
                f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
            )

        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
        batch_size = input.shape[0]

        # normalize input and target along the batch dimension; the epsilon keeps constant
        # features (zero std) from producing a division by zero
        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)

        # cross-correlation matrix
        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF

        # loss: squared distance to identity, with off-diagonal entries down-weighted by lambd
        identity = torch.eye(c.size(0), device=c.device)  # built once, reused for the mask below
        c_diff = (c - identity).pow_(2)  # FxF
        c_diff[~identity.bool()] *= lambd_tensor

        return c_diff.sum()

tests/test_barlow_twins_loss.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.losses import BarlowTwinsLoss
21+
22+
# Each case is [constructor kwargs, forward kwargs, expected loss value].
TEST_CASES = [
    # identical views, shape: (2, 4), (2, 4)
    [
        dict(lambd=5e-3),
        dict(
            input=torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
            target=torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
        ),
        4.0,
    ],
    # different views whose features are constant over the batch, shape: (2, 4), (2, 4)
    [
        dict(lambd=5e-3),
        dict(
            input=torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
            target=torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
        ),
        4.0,
    ],
    # binary features varying over the batch, shape: (2, 4), (2, 4)
    [
        dict(lambd=5e-3),
        dict(
            input=torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
            target=torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
        ),
        5.2562,
    ],
    # non-default lambd, shape: (2, 4), (2, 4)
    [
        dict(lambd=5e-4),
        dict(
            input=torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
            target=torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
        ),
        5.0015,
    ],
    # larger batch, shape: (4, 4), (4, 4)
    [
        dict(lambd=5e-3),
        dict(
            input=torch.tensor(
                [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
            ),
            target=torch.tensor(
                [
                    [0.0, 1.0, -1.0, 0.0],
                    [1 / 3, 0.0, -2 / 3, 1 / 3],
                    [-2 / 3, -1.0, 7 / 3, 1 / 3],
                    [1 / 3, 0.0, 1 / 3, -2 / 3],
                ]
            ),
        ),
        1.4736,
    ],
]
73+
74+
75+
class TestBarlowTwinsLoss(unittest.TestCase):
    """Unit tests for ``BarlowTwinsLoss``: reference values, input validation and device handling."""

    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data, expected_val):
        """The loss matches the precomputed reference value for each parameterized case."""
        barlowtwinsloss = BarlowTwinsLoss(**input_param)
        result = barlowtwinsloss(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

    def test_ill_shape(self):
        """Tensors with more than two dimensions are rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

    def test_mismatched_shape(self):
        """Two-dimensional tensors of different shapes are rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((2, 3)), torch.ones((2, 4)))

    def test_ill_batch_size(self):
        """A batch of a single sample is rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((1, 2)), torch.ones((1, 2)))

    def test_with_cuda(self):
        """The loss runs on the GPU when one is available (and on the CPU otherwise)."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        i = torch.ones((2, 10))
        j = torch.ones((2, 10))
        if torch.cuda.is_available():
            i = i.cuda()
            j = j.cuda()
        output = loss(i, j)
        # constant features normalise to zero, so only the 10 diagonal terms (each equal to 1) remain
        np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)

    def test_unexpected_kwarg_rejected(self):
        """``BarlowTwinsLoss`` takes no ``batch_size`` argument, so passing one is a TypeError."""
        with self.assertRaises(TypeError):
            BarlowTwinsLoss(lambd=5e-3, batch_size=1)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)