import torch
from torch import nn
import torch.nn.functional as F


class PixWiseBCELoss(nn.Module):
    """Weighted sum of a pixel-wise BCE loss and a binary (label) BCE loss.

    ``loss = beta * BCE(net_mask, target_mask) + (1 - beta) * BCE(net_label, target_label)``

    Both network outputs are fed straight into ``nn.BCELoss``, which does not
    apply a sigmoid, so they must already be probabilities in [0, 1].

    Args:
        beta: Weight of the pixel-wise term; ``1 - beta`` weights the label term.
    """

    def __init__(self, beta=0.5):
        super().__init__()
        self.criterion = nn.BCELoss()
        self.beta = beta

    @staticmethod
    def _match_target(prediction, target):
        """Cast *target* to *prediction*'s dtype/device and, if sizes agree, its shape."""
        # BCELoss rejects integer/bool targets, so convert to the prediction's float dtype.
        target = torch.as_tensor(target, dtype=prediction.dtype, device=prediction.device)
        # BCELoss also rejects e.g. a (B,) target for a (B, 1) prediction; reshape
        # only when the element counts match, otherwise let BCELoss report the mismatch.
        if target.shape != prediction.shape and target.numel() == prediction.numel():
            target = target.reshape(prediction.shape)
        return target

    def forward(self, net_mask, net_label, target_mask, target_label):
        """Return the beta-weighted combination of the mask and label BCE losses.

        Args:
            net_mask: Predicted pixel-wise probabilities.
            net_label: Predicted label probabilities.
            target_mask: Ground-truth mask; same number of elements as ``net_mask``.
            target_label: Ground-truth labels; same number of elements as ``net_label``.

        Returns:
            Scalar tensor with the combined loss.
        """
        target_mask = self._match_target(net_mask, target_mask)
        target_label = self._match_target(net_label, target_label)
        pixel_loss = self.criterion(net_mask, target_mask)
        binary_loss = self.criterion(net_label, target_label)
        loss = pixel_loss * self.beta + binary_loss * (1 - self.beta)
        return loss