import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import defaultdict


def dice_loss(input_mask, cls_gt):
    num_objects = input_mask.shape[1]
    losses = []
    for i in range(num_objects):
        mask = input_mask[:, i].flatten(start_dim=1)
        # background is not in the mask channels, so object i corresponds to label i+1
        gt = (cls_gt == (i + 1)).float().flatten(start_dim=1)
        numerator = 2 * (mask * gt).sum(-1)
        denominator = mask.sum(-1) + gt.sum(-1)
        # +1 smoothing avoids division by zero when both mask and gt are empty
        loss = 1 - (numerator + 1) / (denominator + 1)
        losses.append(loss)
    return torch.cat(losses).mean()


# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
class BootstrappedCE(nn.Module):
    """Cross-entropy that, after a warm-up period, averages only the hardest
    top_p fraction of pixels (online hard example mining)."""

    def __init__(self, start_warm, end_warm, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        # before start_warm: plain cross-entropy over all pixels
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            # linearly anneal the kept fraction from 1.0 down to top_p between
            # start_warm and end_warm
            this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / (self.end_warm - self.start_warm))
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p


class LossComputer:
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])

    def compute(self, data, num_objects, it):
        losses = defaultdict(int)

        b, t = data['rgb'].shape[:2]

        losses['total_loss'] = 0
        # the first frame serves as the reference, so losses start at frame 1
        for ti in range(1, t):
            for bi in range(b):
                # slice logits to background + the number of objects in this sample
                loss, p = self.bce(data[f'logits_{ti}'][bi:bi + 1, :num_objects[bi] + 1],
                                   data['cls_gt'][bi:bi + 1, ti, 0], it)
                losses['p'] += p / b / (t - 1)
                losses[f'ce_loss_{ti}'] += loss / b

            losses['total_loss'] += losses[f'ce_loss_{ti}']
            losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:, ti, 0])
            losses['total_loss'] += losses[f'dice_loss_{ti}']

        return losses
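

# Usage sketch (not part of the training code): exercises LossComputer on random
# tensors. The shapes below are assumptions about the data dict layout implied by
# compute() -- 'rgb' is (B, T, C, H, W), 'cls_gt' is (B, T, 1, H, W) with integer
# labels in [0, num_objects], and 'logits_{ti}' / 'masks_{ti}' are per-frame
# tensors; the warm-up iteration counts are arbitrary.
if __name__ == '__main__':
    b, t, h, w = 2, 3, 24, 24
    max_num_objects = 2
    num_classes = max_num_objects + 1  # background + objects

    data = {
        'rgb': torch.randn(b, t, 3, h, w),
        'cls_gt': torch.randint(0, num_classes, (b, t, 1, h, w)),
    }
    for ti in range(1, t):
        logits = torch.randn(b, num_classes, h, w)
        data[f'logits_{ti}'] = logits
        # masks are per-class probabilities with the background channel dropped
        data[f'masks_{ti}'] = F.softmax(logits, dim=1)[:, 1:]

    loss_computer = LossComputer({'start_warm': 1000, 'end_warm': 5000})
    num_objects = [max_num_objects] * b
    losses = loss_computer.compute(data, num_objects, it=2000)
    print({k: float(v) for k, v in losses.items()})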