import torch.nn as nn


def get_volume_label_loss(opt):
    """Build the volume-label loss module.

    Args:
        opt: Options object. It is not read by this factory; presumably it is
            accepted so the signature matches other loss factories
            (NOTE(review): confirm against callers).

    Returns:
        A new ``VolumeLabelLoss`` using the default ``"mean"`` reduction.
    """
    return VolumeLabelLoss()


class VolumeLabelLoss(nn.Module):
    """Binary cross-entropy loss between predictions and labels.

    Args:
        reduction: Reduction mode forwarded to ``nn.BCELoss``
            (``"mean"``, ``"sum"`` or ``"none"``). Defaults to ``"mean"``,
            which was previously hard-coded.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.BCE_loss = nn.BCELoss(reduction=reduction)

    def forward(self, pred, volume, label):
        """Compute the BCE loss of ``pred`` against ``label``.

        Args:
            pred: Predicted probabilities. ``nn.BCELoss`` applies no sigmoid,
                so values must already lie in [0, 1].
            volume: Not used by this loss. Presumably kept so all losses share
                one call signature (NOTE(review): confirm).
            label: Targets with the same shape as ``pred``.

        Returns:
            A dict ``{"loss": loss}``, where ``loss`` is the BCE value reduced
            according to ``reduction``.
        """
        # BCELoss requires the target dtype to match the input dtype. Integer
        # or float64 labels would otherwise raise, so cast them to pred's
        # dtype. This is a no-op when the dtypes already agree.
        loss = self.BCE_loss(pred, label.to(pred.dtype))
        return {"loss": loss}