123 changes: 90 additions & 33 deletions monai/losses/unified_focal_loss.py
@@ -22,9 +22,9 @@

class AsymmetricFocalTverskyLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes.

Actually, it's only supported for binary image segmentation now.
Supports multi-class segmentation with optional background inclusion.

Reimplementation of the Asymmetric Focal Tversky Loss described in:

@@ -39,26 +39,29 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
include_background: bool = True,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : small constant added for numerical stability, similar to a smoothing value. Defaults to 1e-7.
include_background: whether to include background class in loss calculation. Defaults to True.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.include_background: bool = include_background

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

@@ -74,21 +77,27 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon)

# Calculate losses separately for each class, enhancing both classes
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
back_dice = 1 - dice_class[:, 0:1]
fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)

if not self.include_background:
back_dice = back_dice * 0.0

all_dice = torch.cat([back_dice, fore_dice], dim=1)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
loss = torch.mean(all_dice)
return loss
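
As a quick sanity check of the refactored class, a minimal usage sketch (the tensors, shapes, and values below are illustrative, not part of this PR). Note that the foreground term (1 - dice) * (1 - dice)^(-gamma) simplifies to (1 - dice)^(1 - gamma), which is what gives the loss its asymmetric focusing on foreground classes:

import torch
from monai.losses.unified_focal_loss import AsymmetricFocalTverskyLoss

# Illustrative two-class (background/foreground) probabilities and integer labels.
pred = torch.softmax(torch.randn(2, 2, 32, 32), dim=1)
grnd = torch.randint(0, 2, (2, 1, 32, 32))

loss_fn = AsymmetricFocalTverskyLoss(to_onehot_y=True, delta=0.7, gamma=0.75)
print(loss_fn(pred, grnd))  # scalar tensor

# With the new flag, the background term is zeroed out of the class average.
loss_fg = AsymmetricFocalTverskyLoss(to_onehot_y=True, include_background=False)
print(loss_fg(pred, grnd))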


class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes.

Actually, it's only supported for binary image segmentation now.
Supports multi-class segmentation with optional background inclusion.

Reimplementation of the Asymmetric Focal Loss described in:

@@ -103,26 +112,29 @@ def __init__(
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
include_background: bool = True,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : small constant added for numerical stability, similar to a smoothing value. Defaults to 1e-7.
include_background: whether to include background class in loss calculation. Defaults to True.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.include_background: bool = include_background

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

@@ -132,21 +144,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = cross_entropy[:, 1:]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
if not self.include_background:
back_ce = back_ce * 0.0

all_ce = torch.cat([back_ce, fore_ce], dim=1)

loss = torch.mean(all_ce)
return loss
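
For readers skimming the diff, a standalone sketch of what the updated forward pass computes for the background vs. foreground channels (tensor values are illustrative; delta and gamma mirror the defaults above, and this is an editor's reproduction rather than the PR's code):

import torch

delta, gamma, eps = 0.7, 2.0, 1e-7

y_pred = torch.softmax(torch.randn(1, 3, 8, 8), dim=1)  # probabilities, 3 classes
y_true = torch.nn.functional.one_hot(
    torch.randint(0, 3, (1, 8, 8)), num_classes=3
).permute(0, 3, 1, 2).float()

y_pred = torch.clamp(y_pred, eps, 1.0 - eps)
ce = -y_true * torch.log(y_pred)

# Background channel: down-weighted by (1 - delta) and focal-suppressed by (1 - p)^gamma.
back_ce = (1 - delta) * torch.pow(1 - y_pred[:, 0:1], gamma) * ce[:, 0:1]
# Foreground channels: plain cross entropy scaled by delta (no focal suppression).
fore_ce = delta * ce[:, 1:]

loss = torch.mean(torch.cat([back_ce, fore_ce], dim=1))
print(loss)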


class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.

Actually, it's only supported for binary image segmentation now
Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion.

Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:

@@ -162,15 +179,20 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
include_background: bool = True,
use_softmax: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
num_classes : number of classes. Defaults to 2.
weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : small constant added for numerical stability, similar to a smoothing value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
include_background : whether to include the background class in loss calculation. Defaults to True.
use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.

Example:
>>> import torch
@@ -179,26 +201,42 @@ def __init__(
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> fl(pred, grnd)
>>> # Multiclass example with 3 classes
>>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32)
>>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64)
>>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
>>> fl_mc(pred_mc, grnd_mc)
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.include_background: bool = include_background
self.use_softmax = use_softmax
self.asy_focal_loss = AsymmetricFocalLoss(
to_onehot_y=self.to_onehot_y,
gamma=self.gamma,
delta=self.delta,
include_background=self.include_background,
reduction=LossReduction.NONE,
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
to_onehot_y=self.to_onehot_y,
gamma=self.gamma,
delta=self.delta,
include_background=self.include_background,
reduction=LossReduction.NONE,
)

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
a sigmoid or softmax in the forward function.
y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
ValueError: When input and target are different shape
@@ -212,20 +250,39 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
if y_true.shape[1] == self.num_classes:
if not torch.all((y_true == 0) | (y_true == 1)):
raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
elif y_true.shape[1] == 1:
if torch.max(y_true) >= self.num_classes:
raise ValueError(
f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}"
)
else:
raise ValueError(
f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}"
)
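
A hedged illustration of the two target layouts this validation accepts (num_classes and shapes are illustrative, not taken from the PR):

import torch

num_classes = 3

# Label-index targets: a single channel with values in [0, num_classes - 1].
y_label = torch.randint(0, num_classes, (1, 1, 32, 32))

# One-hot targets: num_classes channels containing only 0s and 1s.
y_onehot = torch.nn.functional.one_hot(
    y_label[:, 0], num_classes=num_classes
).permute(0, 3, 1, 2).float()

# Any other channel count, a label >= num_classes, or a "one-hot" tensor with
# values other than 0/1 is rejected by the checks above.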

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_pred.shape[1] == 1:
warnings.warn("single channel prediction, augmenting with background channel.", stacklevel=2)
y_pred_sigmoid = torch.sigmoid(y_pred.float())
y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)

if y_true.shape[1] == 1:
y_true = one_hot(y_true, num_classes=self.num_classes)
else:
if self.use_softmax:
y_pred = torch.softmax(y_pred.float(), dim=1)
else:
y_pred = y_pred.float()
Comment on lines +280 to +284
⚠️ Potential issue | 🔴 Critical

Missing sigmoid activation.

Docstring (line 195) states "If False, sigmoid is used," but line 284 only casts to float without applying sigmoid. For multi-channel binary segmentation with use_softmax=False, raw logits are passed to sub-losses, which expect probabilities (see clamping at lines 72, 144).

Apply this diff:

         else:
             if self.use_softmax:
                 y_pred = torch.softmax(y_pred.float(), dim=1)
             else:
-                y_pred = y_pred.float()
+                y_pred = torch.sigmoid(y_pred.float())
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 280 to 284, when
self.use_softmax is False the code only casts y_pred to float but does not apply
sigmoid as the docstring promises; update the branch to apply torch.sigmoid to
the float-cast logits (e.g., y_pred = torch.sigmoid(y_pred.float())) so
downstream sub-losses receive probabilities for multi-channel binary
segmentation, preserving device/dtype semantics as needed.
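
To make the reviewer's point concrete, a small standalone sketch (not the PR's code; use_softmax mirrors the flag above) showing how the suggested fix maps raw logits into valid probabilities before the sub-losses clamp and log them:

import torch

use_softmax = False
logits = torch.randn(1, 2, 4, 4)  # raw network outputs, unbounded

if use_softmax:
    probs = torch.softmax(logits.float(), dim=1)
else:
    # The suggested fix: squash logits into [0, 1] so the sub-losses'
    # clamp-and-log steps operate on probabilities rather than raw logits.
    probs = torch.sigmoid(logits.float())

assert torch.all((probs >= 0) & (probs <= 1))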


asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
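
The remainder of forward is collapsed in this view. Per the weight docstring above and the Unified Focal loss paper, the two terms are expected to be blended along the lines of

loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss  # hypothetical sketch, not the collapsed code

before the configured reduction is applied; the hidden lines should be checked for the exact handling.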
