File size: 343 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn as nn


def get_volume_label_loss(opt):
    return VolumeLabelLoss()


class VolumeLabelLoss(nn.Module):
    def __init__(self):
        super().__init__()

        self.BCE_loss = nn.BCELoss(reduction="mean")

    def forward(self, pred, volume, label):
        loss = self.BCE_loss(pred, label)
        return {"loss": loss}