import torch
import torch.nn as nn
import cv2
import numpy as np
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from utils.utils import postprocess, display, BBoxTransform, ClipBoxes
from typing import Optional, List
from functools import partial

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"


def calc_iou(a, b):
    # a(anchor) [boxes, (y1, x1, y2, x2)]
    # b(gt, coco-style) [boxes, (x1, y1, x2, y2)]
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih
    IoU = intersection / ua

    return IoU
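# A minimal sanity check for calc_iou (illustrative sketch, not part of the
# training code). Note the mixed conventions: anchors use (y1, x1, y2, x2)
# while ground-truth boxes use (x1, y1, x2, y2).
#
#   anchor = torch.tensor([[0., 0., 10., 10.]])  # (y1, x1, y2, x2)
#   gt = torch.tensor([[5., 0., 15., 10.]])      # (x1, y1, x2, y2)
#   calc_iou(anchor, gt)  # tensor([[0.3333]]): 50 px intersection / 150 px union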
class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, classifications, regressions, anchors, annotations, **kwargs):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which it is
        dtype = anchors.dtype

        anchor_widths = anchor[:, 3] - anchor[:, 1]
        anchor_heights = anchor[:, 2] - anchor[:, 0]
        anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # No ground truth in this image: every anchor is background,
                # so only the negative classification term contributes.
                alpha_factor = torch.ones_like(classification) * alpha
                if torch.cuda.is_available():
                    alpha_factor = alpha_factor.cuda()
                alpha_factor = 1. - alpha_factor

                focal_weight = classification
                focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                bce = -(torch.log(1.0 - classification))
                cls_loss = focal_weight * bce

                zero = torch.tensor(0).to(dtype)
                if torch.cuda.is_available():
                    zero = zero.cuda()
                regression_losses.append(zero)
                classification_losses.append(cls_loss.sum())
                continue

            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)

            # compute the loss for classification
            targets = torch.zeros_like(classification)
            if torch.cuda.is_available():
                targets = targets.cuda()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            # Size-aware anchor assignment: anchors matched to ground-truth
            # boxes larger than 10x10 px become positives at IoU >= 0.5, while
            # anchors matched to smaller boxes qualify already at IoU >= 0.15.
            positive_indices = torch.full_like(IoU_max, False, dtype=torch.bool)
            is_large_box = (assigned_annotations[:, 2] - assigned_annotations[:, 0]) * \
                           (assigned_annotations[:, 3] - assigned_annotations[:, 1]) > 10 * 10
            positive_indices[torch.logical_or(torch.logical_and(is_large_box, IoU_max >= 0.5),
                                              torch.logical_and(~is_large_box, IoU_max >= 0.15))] = True

            num_positive_anchors = positive_indices.sum()

            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones_like(targets) * alpha
            if torch.cuda.is_available():
                alpha_factor = alpha_factor.cuda()

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
            cls_loss = focal_weight * bce

            zeros = torch.zeros_like(cls_loss)
            if torch.cuda.is_available():
                zeros = zeros.cuda()
            # Kept from the RetinaNet-style implementation; targets built above
            # contain no -1 entries, so this where() currently changes nothing.
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
                targets = targets.t()

                # Smooth L1 loss with the quadratic/linear transition at |x| = 1/9
                regression_diff = torch.abs(targets - regression[positive_indices, :])
                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
                else:
                    regression_losses.append(torch.tensor(0).to(dtype))

        # debug: optionally render the current detections when `imgs` is passed in
        imgs = kwargs.get('imgs', None)
        if imgs is not None:
            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            obj_list = kwargs.get('obj_list', None)
            out = postprocess(imgs.detach(),
                              torch.stack([anchors[0]] * imgs.shape[0], 0).detach(),
                              regressions.detach(), classifications.detach(),
                              regressBoxes, clipBoxes, 0.25, 0.3)
            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
            display(out, imgs, obj_list, imshow=False, imwrite=True)

        # box loss weight of 50, following
        # https://github.com/google/automl/blob/6fdd1de778408625c1faf368a327fe36ecd41bf7/efficientdet/hparams_config.py#L233
        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
               torch.stack(regression_losses).mean(dim=0, keepdim=True) * 50
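# A hedged sketch of how this detection FocalLoss is typically wired into an
# EfficientDet-style training step. The tensor shapes below are assumptions
# inferred from the code above, not part of this module:
#
#   criterion = FocalLoss()
#   # classifications: (B, num_anchors, num_classes) sigmoid probabilities
#   # regressions:     (B, num_anchors, 4) deltas in (dy, dx, dh, dw) order
#   # anchors:         (1, num_anchors, 4) in (y1, x1, y2, x2)
#   # annotations:     (B, max_boxes, 5) rows of (x1, y1, x2, y2, class_id),
#   #                  padded with class_id = -1 for images with fewer boxes
#   cls_loss, reg_loss = criterion(classifications, regressions, anchors, annotations)
#   loss = (cls_loss + reg_loss).mean()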
def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.

    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of arbitrary shape (predictions of the model)
        target: Tensor of the same shape as input
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be in the [0...1]
            range; high values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of elements in the output,
            'sum': the output will be summed,
            'batchwise_mean': the output will be summed over the batch dimension.
            Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps: Small epsilon used to clamp the normalization factor.

    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type(output.type())

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)

    # compute the loss
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss
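# A quick illustrative call (values chosen arbitrarily): raw logits go in,
# no sigmoid should be applied beforehand.
#
#   logits = torch.tensor([2.0, -1.0, 0.5])
#   labels = torch.tensor([1.0, 0.0, 1.0])
#   focal_loss_with_logits(logits, labels, gamma=2.0, alpha=0.25)  # scalar mean loss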
class FocalLossSeg(_Loss):
    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        """Compute Focal loss

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having positive value in target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss computation.
            reduction (string, optional): 'none' | 'mean' | 'sum' | 'batchwise_mean'. Default: 'mean'
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
                should use `reduction="sum"`.

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            y_true = y_true.view(-1)
            y_pred = y_pred.view(-1)

            if self.ignore_index is not None:
                # Filter predictions with ignore label from loss computation
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            loss = self.focal_loss_fn(y_pred, y_true)

        elif self.mode == MULTICLASS_MODE:
            num_classes = y_pred.size(1)
            loss = 0

            # Filter anchors with the ignore label from loss computation
            if self.ignore_index is not None:
                not_ignored = y_true != self.ignore_index

            # y_true is expected to be one-hot encoded here, shape (N, C, ...),
            # so each class channel is filtered by its own slice of the mask.
            for cls in range(num_classes):
                cls_y_true = y_true[:, cls, ...]
                cls_y_pred = y_pred[:, cls, ...]

                if self.ignore_index is not None:
                    cls_not_ignored = not_ignored[:, cls, ...]
                    cls_y_true = cls_y_true[cls_not_ignored]
                    cls_y_pred = cls_y_pred[cls_not_ignored]

                loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

        return loss
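# A hedged usage sketch for the segmentation focal loss; the shapes are
# assumptions (one-hot targets in multiclass mode, matching the loop above):
#
#   criterion = FocalLossSeg(mode=MULTICLASS_MODE, alpha=0.25, gamma=2.0)
#   y_pred = torch.randn(2, 3, 64, 64)  # raw logits, (N, C, H, W)
#   y_true = F.one_hot(torch.randint(0, 3, (2, 64, 64)), 3).permute(0, 3, 1, 2).float()
#   loss = criterion(y_pred, y_true)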
def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    # Fail loudly instead of silently returning None
    raise ValueError(f"Unsupported input type: {type(x)}")


def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score


class DiceLoss(_Loss):
    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
    ):
        """Dice loss for image segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes: List of classes that contribute in loss computation. By default, all channels are included.
            log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff`
            from_logits: If True, assumes input is raw logits
            smooth: Smoothness constant for dice coefficient
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            eps: A small epsilon for numerical stability to avoid zero division error
                (denominator will be always greater or equal to eps)

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            # y_true is expected to already be one-hot encoded with shape
            # (N, C, H, W); no F.one_hot conversion is applied here.
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for empty classes, so we zero the contribution
        # of channels that have no true pixels.
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss
        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    def aggregate_loss(self, loss):
        return loss.mean()

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_dice_score(output, target, smooth, eps, dims)
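# A short illustrative call for DiceLoss (shapes assumed, following the
# one-hot multiclass convention noted in forward above):
#
#   criterion = DiceLoss(mode=MULTICLASS_MODE, from_logits=True)
#   y_pred = torch.randn(2, 3, 64, 64)  # raw logits
#   y_true = F.one_hot(torch.randint(0, 3, (2, 64, 64)), 3).permute(0, 3, 1, 2)
#   loss = criterion(y_pred, y_true)    # scalar, mean over class channels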
def soft_tversky_score(
    output: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    beta: float,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    assert output.size() == target.size()
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)  # TP
        fp = torch.sum(output * (1.0 - target), dim=dims)
        fn = torch.sum((1 - output) * target, dim=dims)
    else:
        intersection = torch.sum(output * target)  # TP
        fp = torch.sum(output * (1.0 - target))
        fn = torch.sum((1 - output) * target)

    tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps)

    return tversky_score


class TverskyLoss(DiceLoss):
    """Tversky loss for image segmentation task.
    Where FP and FN are weighted by the alpha and beta params.
    With alpha == beta == 0.5, this loss becomes equal to DiceLoss.
    It supports binary, multiclass and multilabel cases

    Args:
        mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        classes: Optional list of classes that contribute in loss computation;
            By default, all channels are included.
        log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
        from_logits: If True assumes input is raw logits
        smooth: Smoothness constant for the Tversky score
        ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        eps: Small epsilon for numerical stability
        alpha: Weight constant that penalizes the model for FPs (False Positives)
        beta: Weight constant that penalizes the model for FNs (False Negatives)
        gamma: Constant that raises the aggregated loss to a power. Defaults to ``1.0``

    Return:
        loss: torch.Tensor
    """

    def __init__(
        self,
        mode: str,
        classes: List[int] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss):
        return loss.mean() ** self.gamma

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
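# An illustrative instantiation (parameter values are an assumption, chosen to
# penalize false negatives more heavily than false positives, which biases the
# model toward higher recall):
#
#   criterion = TverskyLoss(mode=BINARY_MODE, alpha=0.3, beta=0.7)
#   y_pred = torch.randn(2, 1, 64, 64)           # raw logits
#   y_true = torch.randint(0, 2, (2, 1, 64, 64))
#   loss = criterion(y_pred, y_true)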