| """Loss builder. Everything is treated as MULTICLASS (binary == 2 classes), |
| which sidesteps the binary/multiclass mode pitfall and unifies all datasets. |
| |
| ce_dice : CrossEntropy + multiclass Dice (default, robust for medical seg) |
| ce : CrossEntropy only |
| dice : multiclass Dice only |
| |
| Inputs: logits [B,C,H,W], target [B,H,W] (long, ids 0..C-1). |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import segmentation_models_pytorch as smp |
|
|
|
|
class CEDiceLoss(nn.Module):
    """Segmentation loss built from CrossEntropy and/or multiclass Dice.

    Args:
        mode: ``"ce_dice"`` (CE + Dice, default), ``"ce"`` (CE only) or
            ``"dice"`` (Dice only).
        ignore_index: Optional target label excluded from both losses.
            ``None`` (default) constructs both losses with their library
            defaults, i.e. exactly as before this parameter existed.

    Raises:
        ValueError: If ``mode`` is not one of ``MODES``. Unknown names used to
            fall through silently to ``"ce_dice"``, hiding config typos.
    """

    # Closed set of supported mode names, checked in __init__.
    MODES = ("ce_dice", "ce", "dice")

    def __init__(self, mode: str = "ce_dice", *, ignore_index: int | None = None):
        super().__init__()
        # Validate before constructing any loss so a bad config fails fast
        # instead of silently training with CE + Dice.
        if mode not in self.MODES:
            raise ValueError(
                f"Unknown loss mode {mode!r}; expected one of: {', '.join(self.MODES)}"
            )
        self.mode = mode
        # Only forward ignore_index when given, so the default path keeps each
        # library's own default (CE: -100, smp Dice: None).
        extra = {} if ignore_index is None else {"ignore_index": ignore_index}
        self.ce = nn.CrossEntropyLoss(**extra)
        self.dice = smp.losses.DiceLoss(
            mode=smp.losses.MULTICLASS_MODE, from_logits=True, **extra
        )

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the loss of ``logits`` against integer class ids ``target``."""
        if self.mode == "ce":
            return self.ce(logits, target)
        if self.mode == "dice":
            return self.dice(logits, target)
        # "ce_dice": unweighted sum of both terms.
        return self.ce(logits, target) + self.dice(logits, target)
|
|
|
|
def build_loss(name: str = "ce_dice") -> nn.Module:
    """Create the segmentation loss selected by ``name``.

    ``name`` is passed through unchanged as the ``mode`` of ``CEDiceLoss``
    (``"ce_dice"``, ``"ce"`` or ``"dice"``).
    """
    loss_fn = CEDiceLoss(mode=name)
    return loss_fn
|
|