File size: 1,105 Bytes
f7f604d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn.functional as F

def bce_loss(pred, mask, reduction='none'):
    bce = F.binary_cross_entropy(pred, mask, reduction=reduction)
    return bce

def weighted_bce_loss(pred, mask, reduction='none'):
    weight = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    weight = weight.flatten()
    
    bce = weight * bce_loss(pred, mask, reduction='none').flatten()
    
    if reduction == 'mean':
        bce = bce.mean()
    
    return bce

def iou_loss(pred, mask, reduction='none'):
    inter = pred * mask
    union = pred + mask
    iou = 1 - (inter + 1) / (union - inter + 1)

    if reduction == 'mean':
        iou = iou.mean()

    return iou

def bce_loss_with_logits(pred, mask, reduction='none'):
    return bce_loss(torch.sigmoid(pred), mask, reduction=reduction)

def weighted_bce_loss_with_logits(pred, mask, reduction='none'):
    return weighted_bce_loss(torch.sigmoid(pred), mask, reduction=reduction)

def iou_loss_with_logits(pred, mask, reduction='none'):
    return iou_loss(torch.sigmoid(pred), mask, reduction=reduction)