| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import torch.distributed as dist
|
| |
|
| | from torch import Tensor
|
| | import torchvision
|
| | import torch.distributed as dist
|
| | from typing import List, Optional
|
| |
|
| |
|
| | def _max_by_axis(the_list):
|
| |
|
| | maxes = the_list[0]
|
| | for sublist in the_list[1:]:
|
| | for index, item in enumerate(sublist):
|
| | maxes[index] = max(maxes[index], item)
|
| | return maxes
|
| |
|
| |
|
class NestedTensor(object):
    """Container pairing a (padded) batch tensor with an optional mask tensor."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor whose tensor and mask (if any) are moved to ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # str() (not an f-string) so 0-dim tensors keep their "tensor(...)" form.
        return str(self.tensors)
|
| |
|
| |
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3-D tensors to a shared size and batch them.

    Each tensor is copied into the leading corner of a zero tensor of shape
    ``(len(tensor_list), *per_dim_max)``. The returned bool mask has shape
    ``(batch, max_dim1, max_dim2)`` and is False where real data was copied,
    True on padding.

    Raises:
        ValueError: if the first tensor is not 3-D.
    """
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")
    if torchvision._is_tracing():
        # While tracing, defer to the out-of-place (F.pad based) variant.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    reference = tensor_list[0]
    batch_shape = [len(tensor_list)] + _max_by_axis([list(t.shape) for t in tensor_list])
    num_items, _, max_h, max_w = batch_shape
    padded = torch.zeros(batch_shape, dtype=reference.dtype, device=reference.device)
    padding_mask = torch.ones((num_items, max_h, max_w), dtype=torch.bool, device=reference.device)
    for slot, item in enumerate(tensor_list):
        padded[slot, : item.shape[0], : item.shape[1], : item.shape[2]].copy_(item)
        padding_mask[slot, : item.shape[1], : item.shape[2]] = False
    return NestedTensor(padded, padding_mask)
|
| |
|
| |
|
| |
|
| |
|
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Trace-friendly variant of ``nested_tensor_from_tensor_list``.

    Pads every tensor in ``tensor_list`` up to the per-dimension maximum using
    out-of-place ``F.pad`` calls (no in-place copies). It also builds, per tensor,
    a bool mask over the last two dimensions that is True on padded positions.
    The padding indexes ``padding[0..2]``, so inputs are assumed to be 3-D
    (e.g. (C, H, W) — TODO confirm).

    Returns:
        NestedTensor of the stacked padded tensors and the stacked masks.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        # NOTE(review): torch.stack over img.shape[i] presumably relies on sizes
        # being tensors while tracing; with plain Python ints (eager mode) it would
        # fail. The visible caller only takes this path when tracing.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed in each dimension for this tensor.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad pairs run from the last dim backwards; pad only at the end of dims 2, 1, 0.
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the original extent, padded with 1 -> True marks padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
| |
|
| |
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and a process group is initialized."""
    # Short-circuit: is_initialized() is only consulted when the package is available.
    return bool(dist.is_available() and dist.is_initialized())
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def get_world_size() -> int:
    """Return the distributed world size, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
| |
|
| |
|
def dice_loss(inputs, targets, num_masks):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of shape (N, ...) holding logits.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_masks: normalizer; the summed per-example loss is divided by it.
    Returns:
        Scalar tensor: sum over examples of (1 - smoothed dice) / num_masks.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Flatten targets as well so both line up for any (N, ...) shape; this is a
    # no-op for inputs that are already (N, K).
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the ratio defined for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
|
| |
|
| |
|
def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of logits with at least 2 dims, (N, K, ...).
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_masks: normalizer; the reduced loss is divided by it.
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.25; pass a
                negative value to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples. Default = 2.
    Returns:
        Loss tensor: the element-wise loss averaged over dim 1, summed over
        the remaining dims, and divided by ``num_masks``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: predicted probability assigned to the true label of each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_masks
|
| |
|
| |
|
class SetCriterion(nn.Module):
    """Compute classification and mask losses for per-class queries.

    No Hungarian matcher is used. ``forward`` builds the assignment directly:
    for each image, query ``labels[t]`` is paired with that image's ``t``-th
    target mask. Each requested loss (see ``get_loss``) is then computed on
    the final outputs and, if present, on every entry of
    ``outputs["aux_outputs"]``. Auxiliary keys get an ``_{i}`` suffix.
    """

    def __init__(self, num_classes, weight_dict, losses, eos_coef=0.1):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            weight_dict: dict containing as key the names of the losses and as values their
                         relative weight. Stored only; this class does not apply it.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            eos_coef: relative classification weight applied to the no-object category
        """
        super().__init__()
        self.num_classes = num_classes
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        # Registered as a buffer so it follows module.to(...). It is not pinned to
        # a device here; loss_labels moves it to the logits' device on use, which
        # keeps construction working on CPU-only hosts.
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (weighted cross-entropy).

        Queries selected by ``indices`` are supervised with their target label.
        All other queries are supervised towards the no-object class
        ``self.num_classes``, down-weighted by ``eos_coef``.
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        # Assumes pred_logits is (batch, queries, num_classes + 1) — TODO confirm.
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # cross_entropy expects class scores on dim 1, hence the transpose.
        weight = self.empty_weight.to(src_logits.device)
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, weight)
        return {"loss_ce": loss_ce}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.

        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        Returns {"no_loss": 0} when outputs["pred_masks"] is not 4-D.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        if src_masks.dim() != 4:
            return {"no_loss": 0}
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]

        # Pad the per-image target masks to a common size; the padding mask is unused.
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # Upsample the selected predictions to the (padded) target resolution.
        src_masks = F.interpolate(
            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }

    def _get_src_permutation_idx(self, indices):
        """Flatten ``indices`` into ``(batch_idx, query_idx)`` for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten ``indices`` into ``(batch_idx, target_idx)`` for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        """Dispatch to the loss named ``loss`` ("labels" or "masks")."""
        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        Returns:
            dict of loss tensors; auxiliary-output losses carry an "_{i}" key suffix.
        """
        # Fixed assignment: query labels[t] <-> t-th target of the image. The target
        # indices are created on the labels' device to avoid mixed-device indexing.
        indices = []
        for target in targets:
            label = target["labels"]
            indices.append([label, torch.arange(len(label), device=label.device)])

        # Average the number of target masks across ranks (at least 1) for normalization.
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # Auxiliary decoder outputs reuse the same assignment and losses.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})

        return losses
|
| |
|
| |
|
| |
|
class ATMLoss(nn.Module):
    """ATMLoss.

    Converts per-pixel label maps into per-class targets, runs the set
    criterion on them, and reduces the resulting loss dict to one scalar.
    Each term whose key appears in the criterion's ``weight_dict`` is scaled
    by that weight and by ``atm_loss_weight``. All other terms are ignored.
    """

    def __init__(self,
                 ignore_index,
                 num_classes,
                 dec_layers,
                 mask_weight=20.0,
                 dice_weight=1.0,
                 cls_weight=1.0,
                 atm_loss_weight=1.0,
                 use_point=False):
        super(ATMLoss, self).__init__()
        self.ignore_index = ignore_index
        weight_dict = {"loss_ce": cls_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
        # The same weights are registered for "_{i}"-suffixed keys, i in [0, dec_layers - 1).
        aux_weight_dict = {}
        for i in range(dec_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
        if use_point:
            # NOTE(review): SetCriterion_point is not defined in the visible part of
            # this file; it must be provided elsewhere or use_point=True raises NameError.
            self.criterion = SetCriterion_point(
                num_classes,
                weight_dict=weight_dict,
                losses=["labels", "masks"],
            )
        else:
            self.criterion = SetCriterion(
                num_classes,
                weight_dict=weight_dict,
                losses=["labels", "masks"],
            )
        self.loss_weight = atm_loss_weight

    def forward(self,
                outputs,
                label,
                ):
        """Return the weighted sum of the criterion's losses as a scalar tensor.

        ``label`` is iterated per image by ``prepare_targets``; presumably a
        (batch, H, W) class map — TODO confirm against the caller.
        """
        targets = self.prepare_targets(label)
        losses = self.criterion(outputs, targets)

        weight_dict = self.criterion.weight_dict
        total_loss = torch.as_tensor(0, dtype=torch.float, device=label.device)
        for name, value in losses.items():
            # Terms without a weight (e.g. "no_loss") do not contribute.
            if name in weight_dict:
                total_loss += value * weight_dict[name] * self.loss_weight
        return total_loss

    def prepare_targets(self, targets):
        """Turn each label map in ``targets`` into per-class binary mask targets.

        For every map (iterated along dim 0) returns a dict with:
            "labels": sorted unique class ids other than ``ignore_index``, as
                int64 so they are valid query indices even for uint8 maps.
            "masks": bool tensor of shape (num_labels, *map.shape) with
                ``map == c`` for each class id ``c``.
        If a map holds no valid class, "labels" is empty and "masks" holds a
        single placeholder ``map == ignore_index``. Presumably this keeps the
        downstream padding helper supplied with a non-empty tensor.
        """
        new_targets = []
        for targets_per_image in targets:
            gt_cls = targets_per_image.unique()
            gt_cls = gt_cls[gt_cls != self.ignore_index]
            if len(gt_cls) == 0:
                masks = (targets_per_image == self.ignore_index).unsqueeze(0)
            else:
                # Broadcast (num_cls, 1, ..., 1) against (1, *map.shape).
                cls_view = gt_cls.view(-1, *([1] * targets_per_image.dim()))
                masks = targets_per_image.unsqueeze(0) == cls_view
            new_targets.append(
                {
                    "labels": gt_cls.long(),
                    "masks": masks,
                }
            )
        return new_targets