import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from face_detect.utils.box_utils import match, log_sum_exp
from face_detect.data import cfg_mnet

# Training-time GPU flag read from the config. It stays a module-level name so
# that anything importing it keeps working; MultiBoxLoss.forward itself takes
# the device from the predictions it is given.
GPU = cfg_mnet['gpu_train']


class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function, extended with a landmark regression term.

    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold
           parameter (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of
           ground truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default bounding
           boxes. (default negative:positive ratio 3:1)

    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    ``forward`` returns the three terms (box, classification, landmark)
    separately and already normalised; combining/weighting them is left to
    the caller.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        """Store the loss configuration.

        Args:
            num_classes: Size of the last dimension of the confidence
                predictions.
            overlap_thresh: Threshold handed to ``match`` when assigning
                ground truth to priors.
            prior_for_matching: Stored as ``use_prior_for_matching``; not read
                by ``forward``.
            bkg_label: Stored as ``background_label``; not read by
                ``forward``.
            neg_mining: Stored as ``do_neg_mining``; not read by ``forward``
                (hard negative mining is always performed).
            neg_pos: Number of negatives kept per positive during hard
                negative mining.
            neg_overlap: Stored; not read by ``forward``.
            encode_target: Stored; not read by ``forward``.
        """
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # Variances forwarded to ``match`` for target encoding.
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Compute the multibox losses for one batch.

        Args:
            predictions (tuple): ``(loc_data, conf_data, landm_data)``.
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                landm shape: torch.size(batch_size,num_priors,10)
            priors (tensor): Prior boxes, shape torch.size(num_priors,4).
            targets: Per-image ground truth, indexable by batch position.
                Each entry is a 2-D tensor with one row per object: columns
                0:4 are the box, columns 4:14 the landmark coordinates and the
                last column the label.

        Returns:
            tuple: ``(loss_l, loss_c, loss_landm)``. ``loss_l`` (SmoothL1 on
            boxes) and ``loss_c`` (cross entropy on positives plus mined
            negatives) are divided by the number of priors with a non-zero
            matched label; ``loss_landm`` (SmoothL1 on landmarks) is divided
            by the number of priors with a strictly positive matched label.
            Both divisors are floored at 1.
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)
        # Work on whatever device the network produced its output on instead
        # of assuming CUDA from the config flag.
        device = loc_data.device

        # Match priors (default boxes) and ground truth boxes. The buffers are
        # filled in place by ``match``; they are zero-initialised so that any
        # row ``match`` leaves untouched still holds defined values.
        loc_t = torch.zeros(num, num_priors, 4)
        landm_t = torch.zeros(num, num_priors, 10)
        conf_t = torch.zeros(num, num_priors, dtype=torch.long)
        defaults = priors.detach()
        for idx in range(num):
            sample = targets[idx].detach()
            truths = sample[:, :4]
            labels = sample[:, -1]
            landms = sample[:, 4:14]
            match(self.threshold, truths, defaults, self.variance, labels,
                  landms, loc_t, conf_t, landm_t, idx)
        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)
        landm_t = landm_t.to(device)

        # Landmark loss (Smooth L1), only over priors whose matched label is
        # strictly positive.
        # NOTE(review): presumably negative labels mark faces without usable
        # landmark annotations - confirm against ``match`` and the dataset.
        pos_landm = conf_t > 0
        landm_norm = pos_landm.long().sum().float().clamp(min=1)
        loss_landm = F.smooth_l1_loss(
            landm_data[pos_landm], landm_t[pos_landm], reduction='sum')

        # Every non-zero label counts as a positive for the box and class
        # losses; collapse them all to class 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization loss (Smooth L1) over positive priors.
        loss_l = F.smooth_l1_loss(loc_data[pos], loc_t[pos], reduction='sum')

        # Hard negative mining: rank priors by their per-prior classification
        # loss. Only the ranking is used, so no autograd graph is needed.
        with torch.no_grad():
            batch_conf = conf_data.view(-1, self.num_classes)
            mining_loss = (log_sum_exp(batch_conf)
                           - batch_conf.gather(1, conf_t.view(-1, 1)))
            mining_loss[pos.view(-1, 1)] = 0  # positives never rank as negatives
            mining_loss = mining_loss.view(num, -1)
            _, loss_idx = mining_loss.sort(1, descending=True)
            # Sorting the sort indices yields each prior's rank.
            _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos,
                              max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss over positives and the selected hard negatives.
        selected = pos | neg
        loss_c = F.cross_entropy(
            conf_data[selected], conf_t[selected], reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        box_norm = num_pos.sum().float().clamp(min=1)
        loss_l = loss_l / box_norm
        loss_c = loss_c / box_norm
        loss_landm = loss_landm / landm_norm

        return loss_l, loss_c, loss_landm