# RPCANet / utils / loss.py
# fengyiwu's picture
# Upload 93 files
# 82b70d0 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftLoULoss(nn.Module):
    """Soft intersection-over-union loss computed on raw logits.

    The prediction is passed through a sigmoid, a smoothed IoU ratio is
    computed per sample (reducing over every axis except the first), and the
    loss is ``1 - mean(ratio)`` over the first axis.

    Args:
        smooth: Constant added to both the numerator and the denominator of
            the IoU ratio so the ratio stays defined when the union is zero.
            Defaults to 1.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    @staticmethod
    def _per_sample_sum(tensor):
        """Sum ``tensor`` over every axis except the first."""
        return torch.sum(tensor, dim=tuple(range(1, tensor.dim())))

    def forward(self, pred, target):
        """Return the scalar soft-IoU loss.

        Args:
            pred: Logits; the sigmoid is applied inside this method, so the
                caller must not apply it beforehand.
            target: Tensor multiplied element-wise with the activated
                prediction (presumably a mask in [0, 1] with the same shape
                as ``pred`` -- confirm against the caller).

        Returns:
            A 0-dim tensor equal to ``1 - mean(per-sample smoothed IoU)``.

        Raises:
            ValueError: If either input has fewer than two dimensions.
        """
        if pred.dim() < 2 or target.dim() < 2:
            # With no non-leading axis the reduction axes would be an empty
            # tuple, which torch treats as "reduce everything".
            raise ValueError(
                "pred and target must have at least 2 dimensions "
                "(first axis is treated as the batch axis)"
            )

        # torch.sigmoid replaces the deprecated F.sigmoid.
        prob = torch.sigmoid(pred)

        intersection_sum = self._per_sample_sum(prob * target)
        pred_sum = self._per_sample_sum(prob)
        target_sum = self._per_sample_sum(target)

        # union = |pred| + |target| - |intersection|
        union_sum = pred_sum + target_sum - intersection_sum
        iou = (intersection_sum + self.smooth) / (union_sum + self.smooth)

        return 1 - torch.mean(iou)