# Source: WSCL/losses/volume_label_loss.py
# Last commit by yhzhai: "release code" (482ab8a)
import torch.nn as nn
def get_volume_label_loss(opt):
    """Build the volume-label binary cross-entropy loss module.

    Args:
        opt: Options object. It is not read by this factory. Presumably it is
            accepted so that all ``get_*_loss(opt)`` factories share one
            calling convention; confirm with the caller.

    Returns:
        A newly constructed ``VolumeLabelLoss`` instance.
    """
    loss_module = VolumeLabelLoss()
    return loss_module
class VolumeLabelLoss(nn.Module):
    """Mean binary cross-entropy between a predicted probability and a label.

    Wraps ``nn.BCELoss(reduction="mean")``. The ``volume`` argument of
    :meth:`forward` is accepted but not used by this loss.
    """

    def __init__(self):
        super().__init__()
        self.BCE_loss = nn.BCELoss(reduction="mean")

    def forward(self, pred, volume, label):
        """Compute the mean BCE between ``pred`` and ``label``.

        Args:
            pred: Tensor of probabilities in [0, 1]. ``nn.BCELoss`` expects
                probabilities, not logits. Presumably a sigmoid is applied
                upstream; confirm with the caller.
            volume: Unused here. Presumably kept so that all losses share a
                ``(pred, volume, label)`` signature.
            label: Target tensor with values in [0, 1] and the same shape as
                ``pred``, as ``nn.BCELoss`` requires. Integer or boolean labels
                are accepted and are cast to ``pred``'s dtype.

        Returns:
            dict with the single key ``"loss"`` holding the scalar mean BCE.
        """
        # nn.BCELoss rejects non-floating targets (e.g. a LongTensor of 0/1
        # labels) and dtype mismatches. Casting to pred's dtype fixes both,
        # and it is a no-op when label already matches.
        target = label.to(pred.dtype)
        loss = self.BCE_loss(pred, target)
        return {"loss": loss}