MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
1.12 kB
"""Loss builder. Everything is treated as MULTICLASS (binary == 2 classes),
which sidesteps the binary/multiclass mode pitfall and unifies all datasets.
ce_dice : CrossEntropy + multiclass Dice (default, robust for medical seg)
ce : CrossEntropy only
dice : multiclass Dice only
Inputs: logits [B,C,H,W], target [B,H,W] (long, ids 0..C-1).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
class CEDiceLoss(nn.Module):
    """Multiclass segmentation loss: CrossEntropy, Dice, or their sum.

    Args:
        mode: ``"ce_dice"`` (CrossEntropy + multiclass Dice, default),
            ``"ce"`` (CrossEntropy only) or ``"dice"`` (multiclass Dice only).

    Raises:
        ValueError: if ``mode`` is not one of :attr:`MODES`. Without this check
            an unknown (e.g. misspelled) name would silently fall through to
            the ``"ce_dice"`` branch of :meth:`forward`.
    """

    # Supported values for ``mode``; the first entry is the default.
    MODES = ("ce_dice", "ce", "dice")

    def __init__(self, mode: str = "ce_dice"):
        super().__init__()
        # Validate before building anything: forward() treats every mode that
        # is not "ce"/"dice" as "ce_dice", so a typo would otherwise go unnoticed.
        if mode not in self.MODES:
            raise ValueError(
                f"Unknown loss mode {mode!r}; expected one of {self.MODES}"
            )
        self.mode = mode
        self.ce = nn.CrossEntropyLoss()
        # Raw logits are passed straight in; with from_logits=True smp's
        # DiceLoss converts them to class probabilities itself.
        self.dice = smp.losses.DiceLoss(
            mode=smp.losses.MULTICLASS_MODE, from_logits=True
        )

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss.

        Args:
            logits: raw (un-normalized) class scores, expected [B, C, H, W].
            target: integer class ids, expected [B, H, W] with values 0..C-1.

        Returns:
            The loss tensor; with CrossEntropyLoss's default ``reduction="mean"``
            this is a scalar.
        """
        if self.mode == "ce":
            return self.ce(logits, target)
        if self.mode == "dice":
            return self.dice(logits, target)
        # "ce_dice": unweighted sum of both terms.
        return self.ce(logits, target) + self.dice(logits, target)
def build_loss(name: str = "ce_dice") -> nn.Module:
    """Return the segmentation loss module selected by *name*.

    *name* is forwarded unchanged as the ``mode`` of ``CEDiceLoss``
    ("ce_dice", "ce" or "dice").
    """
    loss_module = CEDiceLoss(mode=name)
    return loss_module