Commit ec63e06
Update to use log_sigmoid in FocalLoss (#7534)
Fixes #7533

### Description

Use `F.logsigmoid` to compute binary cross entropy with logits in `sigmoid_focal_loss`, replacing the manual `max_val` log-sum-exp formulation, and relax the consistency test tolerance to `atol=1e-6` to absorb the resulting floating-point round-off.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <[email protected]>
1 parent 6b7568d commit ec63e06

2 files changed (+3 -4 lines)

monai/losses/focal_loss.py (+2 -3)

@@ -234,9 +234,8 @@ def sigmoid_focal_loss(
     """
     # computing binary cross entropy with logits
     # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
-    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-    max_val = (-input).clamp(min=0)
-    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
+    loss: torch.Tensor = input - input * target - F.logsigmoid(input)

     # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
     # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
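Since log(sigmoid(x)) = -log(1 + exp(-x)), binary cross entropy with logits reduces to x - x*t - log(sigmoid(x)), so the new one-liner is algebraically equivalent to the removed `max_val` formulation. A minimal sketch checking this numerically against PyTorch's native implementation (not part of the commit; the tensor shapes and seed are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(4, 8) * 10                    # logits, including large magnitudes
target = torch.randint(0, 2, (4, 8)).float()      # binary targets

# new formulation (this commit)
new = input - input * target - F.logsigmoid(input)

# old formulation (removed by this commit)
max_val = (-input).clamp(min=0)
old = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

# PyTorch's native reference
ref = F.binary_cross_entropy_with_logits(input, target, reduction="none")

print(torch.allclose(new, ref, atol=1e-6))  # True
print(torch.allclose(old, ref, atol=1e-6))  # True
```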

tests/test_focal_loss.py (+1 -1)

@@ -132,7 +132,7 @@ def test_consistency_with_cross_entropy_2d_no_reduction(self):
         error = np.abs(a - b)
         max_error = np.maximum(error, max_error)

-        assert np.allclose(max_error, 0)
+        assert np.allclose(max_error, 0, atol=1e-6)

     def test_consistency_with_cross_entropy_2d_onehot_label(self):
         """For gamma=0 the focal loss reduces to the cross entropy loss"""
