import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class PixWiseBCELoss(nn.Module):
    """Weighted sum of a pixel-wise BCE loss and a binary (per-sample) BCE loss.

    The total loss is::

        beta * BCE(net_mask, target_mask) + (1 - beta) * BCE(net_label, target_label)

    Both terms use ``nn.BCELoss`` with its default ``'mean'`` reduction, so the
    predictions must already be probabilities in ``[0, 1]`` (e.g. sigmoid
    outputs), not raw logits.

    Args:
        beta: Weight of the pixel-wise (mask) term; the binary (label) term is
            weighted by ``1 - beta``. Defaults to 0.5, i.e. equal weighting.
    """

    def __init__(self, beta=0.5):
        super().__init__()
        self.criterion = nn.BCELoss()
        self.beta = beta

    @staticmethod
    def _align_target(prediction, target):
        """Return *target* with *prediction*'s dtype, device and shape."""
        # nn.BCELoss refuses integer targets (e.g. LongTensor 0/1 labels) and
        # targets whose shape differs from the input, such as labels of shape
        # (B,) paired with predictions of shape (B, 1).
        target = target.to(device=prediction.device, dtype=prediction.dtype)
        if target.shape != prediction.shape:
            # reshape_as raises when the element counts differ, so a genuine
            # size mismatch is still reported rather than silently hidden.
            target = target.reshape_as(prediction)
        return target

    def forward(self, net_mask, net_label, target_mask, target_label):
        """Compute the weighted pixel-wise + binary BCE loss.

        Args:
            net_mask: Predicted per-pixel probabilities in ``[0, 1]``.
            net_label: Predicted per-sample probabilities in ``[0, 1]``.
            target_mask: Ground-truth mask; cast to ``net_mask``'s dtype and
                reshaped to its shape if the element counts match.
            target_label: Ground-truth labels; cast to ``net_label``'s dtype
                and reshaped to its shape if the element counts match.

        Returns:
            A scalar tensor: ``beta * pixel_loss + (1 - beta) * binary_loss``.
        """
        pixel_loss = self.criterion(
            net_mask, self._align_target(net_mask, target_mask)
        )
        binary_loss = self.criterion(
            net_label, self._align_target(net_label, target_label)
        )
        return pixel_loss * self.beta + binary_loss * (1 - self.beta)