File size: 629 Bytes
82b70d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftLoULoss(nn.Module):
    """Soft intersection-over-union loss computed from raw logits.

    ``pred`` is passed through a sigmoid, then for every sample the ratio

        (intersection + smooth) / (pred_sum + target_sum - intersection + smooth)

    is computed, where ``intersection`` is the sum of ``sigmoid(pred) * target``
    over all non-batch dimensions. The loss is ``1 - mean(ratio)`` over the
    batch.

    Args:
        smooth: Constant added to numerator and denominator. A positive value
            keeps the ratio defined when both the prediction and the target
            sum to zero. Defaults to ``1``.
    """

    def __init__(self, smooth: float = 1) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar soft-IoU loss for a batch.

        Args:
            pred: Raw (pre-sigmoid) scores with the batch on dimension 0 and
                at least one further dimension, e.g. ``(N, C, H, W)``.
            target: Tensor that multiplies element-wise with ``pred``
                (presumably a mask of the same shape -- confirm with callers).

        Returns:
            A 0-dim tensor equal to ``1 - mean(per-sample soft IoU)``.

        Raises:
            ValueError: If ``pred`` has fewer than two dimensions.
        """
        if pred.dim() < 2:
            raise ValueError(
                "pred must have a batch dimension plus at least one more "
                f"dimension, got shape {tuple(pred.shape)}"
            )

        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        prob = torch.sigmoid(pred)

        # Reduce over every non-batch dimension; this is (1, 2, 3) for the
        # usual 4-D input, so existing callers get the same reduction.
        reduce_dims = tuple(range(1, prob.dim()))

        overlap = torch.sum(prob * target, dim=reduce_dims)
        prob_total = torch.sum(prob, dim=reduce_dims)
        target_total = torch.sum(target, dim=reduce_dims)

        # union = |pred| + |target| - |pred AND target|
        union = prob_total + target_total - overlap
        iou = (overlap + self.smooth) / (union + self.smooth)

        return 1 - torch.mean(iou)