modify unified_focal_loss.py #8626
base: dev
@@ -22,9 +22,9 @@
 class AsymmetricFocalTverskyLoss(_Loss):
     """
-    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes.

-    Actually, it's only supported for binary image segmentation now.
+    Supports multi-class segmentation with optional background inclusion.

     Reimplementation of the Asymmetric Focal Tversky Loss described in:
@@ -39,26 +39,29 @@ def __init__(
         gamma: float = 0.75,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
     ) -> None:
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            include_background: whether to include background class in loss calculation. Defaults to True.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
+        self.include_background: bool = include_background

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
@@ -74,21 +77,27 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         fn = torch.sum(y_true * (1 - y_pred), dim=axis)
         fp = torch.sum((1 - y_true) * y_pred, dim=axis)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
+        dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon)

         # Calculate losses separately for each class, enhancing both classes
-        back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        back_dice = 1 - dice_class[:, 0:1]
+        fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)
+
+        if not self.include_background:
+            back_dice = back_dice * 0.0
+
+        all_dice = torch.cat([back_dice, fore_dice], dim=1)

         # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
+        loss = torch.mean(all_dice)
         return loss
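For orientation, a minimal usage sketch of the updated AsymmetricFocalTverskyLoss (illustrative only: it assumes a MONAI build containing this PR and that the class is importable from monai.losses; shapes and values are arbitrary):

import torch
from monai.losses import AsymmetricFocalTverskyLoss

# one-hot target and probability-like prediction for a 3-class problem (BNHW)
y_true = torch.nn.functional.one_hot(
    torch.randint(0, 3, (2, 32, 32)), num_classes=3
).permute(0, 3, 1, 2).float()
y_pred = torch.softmax(torch.randn(2, 3, 32, 32), dim=1)

# include_background=False zeroes the background term before averaging
loss_fn = AsymmetricFocalTverskyLoss(include_background=False)
print(loss_fn(y_pred, y_true))  # scalar tensor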

 class AsymmetricFocalLoss(_Loss):
     """
-    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes.

-    Actually, it's only supported for binary image segmentation now.
+    Supports multi-class segmentation with optional background inclusion.

     Reimplementation of the Asymmetric Focal Loss described in:
@@ -103,26 +112,29 @@ def __init__(
         gamma: float = 2,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            include_background: whether to include background class in loss calculation. Defaults to True.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
+        self.include_background: bool = include_background

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
@@ -132,21 +144,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
         cross_entropy = -y_true * torch.log(y_pred)

-        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
+        back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1]
         back_ce = (1 - self.delta) * back_ce

-        fore_ce = cross_entropy[:, 1]
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce

-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        if not self.include_background:
+            back_ce = back_ce * 0.0
+
+        all_ce = torch.cat([back_ce, fore_ce], dim=1)
+
+        loss = torch.mean(all_ce)
         return loss
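To make the asymmetry concrete, this standalone sketch reproduces the arithmetic of the hunk above with plain torch ops (it assumes the class defaults delta=0.7 and gamma=2, and does not call the MONAI class):

import torch

delta, gamma, epsilon = 0.7, 2.0, 1e-7

# probability-like prediction and one-hot target for a 2-class problem (BNHW)
y_pred = torch.softmax(torch.randn(1, 2, 4, 4), dim=1)
y_true = torch.nn.functional.one_hot(
    torch.randint(0, 2, (1, 4, 4)), num_classes=2
).permute(0, 3, 1, 2).float()

y_pred = torch.clamp(y_pred, epsilon, 1.0 - epsilon)
ce = -y_true * torch.log(y_pred)

back_ce = (1 - delta) * torch.pow(1 - y_pred[:, 0:1], gamma) * ce[:, 0:1]  # focal-suppressed background
fore_ce = delta * ce[:, 1:]                                                # weighted, unsuppressed foreground
print(torch.mean(torch.cat([back_ce, fore_ce], dim=1)))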

 class AsymmetricUnifiedFocalLoss(_Loss):
     """
-    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
+    AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.

-    Actually, it's only supported for binary image segmentation now
+    Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion.

     Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
@@ -162,15 +179,20 @@ def __init__(
         gamma: float = 0.5,
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
+        use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            num_classes : number of classes. Defaults to 2.
+            weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
+            reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
+            include_background : whether to include the background class in loss calculation. Defaults to True.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.

         Example:
             >>> import torch
@@ -179,26 +201,42 @@ def __init__(
             >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
             >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
             >>> fl(pred, grnd)
+            >>> # Multiclass example with 3 classes
+            >>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32)
+            >>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64)
+            >>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
+            >>> fl_mc(pred_mc, grnd_mc)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.include_background: bool = include_background
+        self.use_softmax = use_softmax
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
+        )

-    # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
+                a sigmoid or softmax in the forward function.
+            y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

         Raises:
             ValueError: When input and target are different shape
@@ -212,20 +250,39 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
+        if y_true.shape[1] == self.num_classes:
+            if not torch.all((y_true == 0) | (y_true == 1)):
+                raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
+        elif y_true.shape[1] == 1:
+            if torch.max(y_true) >= self.num_classes:
+                raise ValueError(
+                    f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}"
+                )
+        else:
+            raise ValueError(
+                f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}"
+            )

         n_pred_ch = y_pred.shape[1]
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)

+        if y_pred.shape[1] == 1:
+            warnings.warn("single channel prediction, augmenting with background channel.", stacklevel=2)
+            y_pred_sigmoid = torch.sigmoid(y_pred.float())
+            y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
+
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=self.num_classes)
+        else:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred.float(), dim=1)
+            else:
+                y_pred = y_pred.float()
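The target-format checks added above can also be read in isolation; the sketch below (a standalone illustration with a hypothetical helper name, not code from the PR) shows which target layouts pass:

import torch

def check_target(y_true: torch.Tensor, num_classes: int) -> None:
    # Same rules as the hunk above: accept a one-hot target (num_classes channels)
    # or an integer label map (1 channel); reject anything else.
    if y_true.shape[1] == num_classes:
        if not torch.all((y_true == 0) | (y_true == 1)):
            raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
    elif y_true.shape[1] == 1:
        if torch.max(y_true) >= num_classes:
            raise ValueError(f"y_true labels must be in [0, {num_classes - 1}]")
    else:
        raise ValueError(f"y_true must have {num_classes} channels (one-hot) or 1 channel (labels)")

check_target(torch.randint(0, 3, (1, 1, 8, 8)), num_classes=3)                                 # label map: passes
check_target(torch.eye(3)[torch.randint(0, 3, (1, 8, 8))].permute(0, 3, 1, 2), num_classes=3)  # one-hot: passes
# check_target(torch.ones((1, 2, 8, 8)), num_classes=3)  # wrong channel count: raises ValueError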
Comment on lines +280 to +284

coderabbitai[bot] (Contributor) commented:

Missing sigmoid activation. The docstring (line 195) states "If False, sigmoid is used," but line 284 only casts to float without applying a sigmoid, so for multi-channel input with use_softmax=False the logits are never turned into probabilities. Apply this diff:

         else:
             if self.use_softmax:
                 y_pred = torch.softmax(y_pred.float(), dim=1)
             else:
-                y_pred = y_pred.float()
+                y_pred = torch.sigmoid(y_pred.float())

         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
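End-to-end, the added multiclass docstring example corresponds to roughly the following runnable sketch (assuming a MONAI build that includes this PR; tensor sizes are illustrative):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

pred = torch.randn((1, 3, 32, 32), dtype=torch.float32)        # raw logits for 3 classes
grnd = torch.randint(0, 3, (1, 1, 32, 32), dtype=torch.int64)  # integer label map

loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
print(loss_fn(pred, grnd))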