Oisin Mac Aodha
added bat code
9ace58a
raw
history blame
1.51 kB
import torch
import torch.nn.functional as F
def bbox_size_loss(pred_size, gt_size):
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
gt_size_mask = (gt_size > 0).float()
return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5))
def focal_loss(pred, gt, weights=None, valid_mask=None):
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
gt (batch x c x h x w)
"""
eps = 1e-5
beta = 4
alpha = 2
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds
if weights is not None:
pos_loss = pos_loss*weights
#neg_loss = neg_loss*weights
if valid_mask is not None:
pos_loss = pos_loss*valid_mask
neg_loss = neg_loss*valid_mask
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum()
if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss
def mse_loss(pred, gt, weights=None, valid_mask=None):
"""
Mean squared error loss.
"""
if valid_mask is None:
op = ((gt-pred)**2).mean()
else:
op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum()
return op