print("Importing standard...") from abc import ABC, abstractmethod print("Importing external...") import torch from torch.nn.functional import binary_cross_entropy # from matplotlib import pyplot as plt print("Importing internal...") from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou ######### BINARY LOSSES ############### def my_lovasz_hinge(logits, gt, downsample=False): if downsample: offset = int(torch.randint(downsample - 1, (1,))) logits, gt = logits[:, offset::downsample], gt[:, offset::downsample] # B, HW gt = 1.0 * gt # go float areas = gt.sum(dim=1, keepdims=True) # B, 1 # per_image = True, ignore = None signs = 2 * gt - 1 errors = 1 - logits * signs errors_sorted, perm = torch.sort(errors, dim=1, descending=True) gt_sorted = torch.gather(gt, 1, perm) # B, HW # lovasz grad intersection = areas - gt_sorted.cumsum(dim=1) # B, HW union = areas + (1 - gt_sorted).cumsum(dim=1) # B, HW jaccard = 1 - intersection / union # B, HW jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1] loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1) # B, return torch.nanmean(loss) def focal_loss(scores, targets, alpha=0.25, gamma=2): p = scores ce_loss = binary_cross_entropy(p, targets, reduction="none") p_t = p * targets + (1 - p) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss return loss # also binary_cross_entropy and lovasz ########## SUBFUNCTIONS ######################3 def get_distances(features, refs, sigma, norm_p, square_distances, H, W): # features: B, 1, F, HW # refs: B, M, F, 1 # sigma: B, M, 1, 1 B, M = refs.shape[0], refs.shape[1] distances = torch.norm( features - refs, dim=2, p=norm_p, keepdim=True ) # B, M, 1, H*W distances = distances**2 if square_distances else distances distances = (distances / (2 * sigma**2)).reshape(B, M, H * W) return distances def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction): # sigmoid is very similar to exp # prepare features assert activation in ["sigmoid", "symlog"] if masks is None: # when inferencing B, M = 1, 1 F, N = sorted(features.shape) H, W = [int(N ** (0.5))] * 2 features = features.reshape(1, 1, -1, H * W) else: masks, features, M, B, H, W, F = preprocess_masks_features(masks, features) # features: B, 1, F, H*W # masks: B, M, 1, H*W if use_sigma: sigma = torch.nn.functional.softplus(features)[:, :, -1:] # B, 1, 1, H*W features = features[:, :, :-1] F = features.shape[2] else: sigma = 1 features = symlog(features) if activation == "symlog" else torch.sigmoid(features) if offset_pos: assert F >= 2 row, col = get_row_col(H, W, features.device) row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W) col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W) positional_features = torch.cat([row, col], dim=2) # B, 1, 2, H*W features[:, :, :2] = features[:, :, :2] + positional_features prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None if masks is None: features = features.reshape(-1, H * W) sigma = sigma.reshape(-1, H * W) if use_sigma else 1 return features, sigma, H, W return features, masks, sigma, prediction, B, M, F, H, W class AbstractLoss(ABC): @staticmethod @abstractmethod def loss(features, masks, ret_prediction=False, **kwargs): pass @staticmethod @abstractmethod def get_mask_from_query(features, sindex, **kwargs): pass class IISLoss(AbstractLoss): @staticmethod def loss(features, masks, ret_prediction=False, K=3, 
########## SUBFUNCTIONS ##############


def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    # features: B, 1, F, HW
    # refs: B, M, F, 1
    # sigma: B, M, 1, 1 (or anything broadcastable, including a scalar)
    B, M = refs.shape[0], refs.shape[1]
    distances = torch.norm(
        features - refs, dim=2, p=norm_p, keepdim=True
    )  # B, M, 1, H*W
    distances = distances**2 if square_distances else distances
    distances = (distances / (2 * sigma**2)).reshape(B, M, H * W)
    return distances


def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    # sigmoid is very similar to exp
    # prepare features
    assert activation in ["sigmoid", "symlog"]
    if masks is None:  # when inferencing
        B, M = 1, 1
        F, N = sorted(features.shape)  # assumes fewer channels than pixels
        H, W = [int(N ** (0.5))] * 2  # assumes a square image
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    if use_sigma:
        # the last feature channel parameterizes a per-pixel bandwidth
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        # treat the first two channels as offsets from the pixel coordinates
        assert F >= 2
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
        features[:, :, :2] = features[:, :, :2] + positional_features

    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W


class AbstractLoss(ABC):
    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        pass


class IISLoss(AbstractLoss):
    # Interactive image segmentation loss: sample K query pixels per mask,
    # score every pixel by its feature-space distance to each query, and
    # supervise the scores with focal + Lovász hinge losses.
    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        rindices = torch.randperm(H * W, device=masks.device)
        # the following should work if all masks have more than K pixels
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )  # B, M, K
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )  # B, M, K, F
        feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)  # B, M, K, F, 1
        dists = get_distances(
            features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        score = torch.exp(-dists)  # B, M*K, H*W, values in [0, 1]
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )  # B, M*K, H*W
        floss = focal_loss(score, targets).mean()
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,  # rescale scores to [-1, 1]
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)  # 1, 1, H*W
        pred = score > 0.5
        return pred


def iis_iou(features, masks, get_mask_from_query, K=20):
    # Mean IoU over K random query pixels per mask, using the prediction
    # rule implemented by `get_mask_from_query`.
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    rindices = torch.randperm(H * W, device=masks.device)
    sindices = torch.stack(
        [
            torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
            for b in range(B)
        ]
    )  # B, M, K
    cum_iou, n_samples = 0, 0
    for b in range(B):
        for m in range(M):
            for k in range(K):
                sindex = sindices[b, m, k]
                pred = get_mask_from_query(features[b, 0], sindex)
                iou = calculate_iou(pred, masks[b, m, 0, :])
                cum_iou += iou
                n_samples += 1
    return cum_iou / n_samples


losses_names = [
    "iis",
]


def get_loss_class(loss_name):
    if loss_name == "iis":
        return IISLoss
    else:
        raise NotImplementedError


def get_get_mask_from_query(loss_name):
    loss_class = get_loss_class(loss_name)
    return loss_class.get_mask_from_query


def get_loss(loss_name):
    loss_class = get_loss_class(loss_name)
    return loss_class.loss
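

# Minimal usage sketch (an assumption-laden example, not from the original
# module): the exact input layout depends on utils.preprocess_masks_features,
# which is assumed here to accept features of shape B, F, H, W and boolean
# masks of shape B, M, H, W and flatten them to the B, 1, F, H*W /
# B, M, 1, H*W layout documented above.
if __name__ == "__main__":
    loss_fn = get_loss("iis")
    features = torch.randn(2, 8, 16, 16, requires_grad=True)  # B, F, H, W (assumed)
    masks = torch.zeros(2, 3, 16, 16, dtype=torch.bool)  # B, M, H, W (assumed)
    masks[:, 0, 2:9, 2:9] = True  # every mask needs more than K=3 pixels
    masks[:, 1, 8:14, 1:6] = True
    masks[:, 2, 1:6, 9:15] = True
    loss, _ = loss_fn(features, masks)
    loss.backward()
    print("iis loss:", float(loss))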