# Copyright (c) Facebook, Inc. and its affiliates. import torch import torch.nn.functional as F from torch import nn import math from detectron2.modeling import META_ARCH_REGISTRY, build_backbone from detectron2.structures import Boxes, Instances from ..utils import load_class_freq, get_fed_loss_inds from models.backbone import Joiner from models.deformable_detr import DeformableDETR, SetCriterion, MLP from models.deformable_detr import _get_clones from models.matcher import HungarianMatcher from models.position_encoding import PositionEmbeddingSine from models.deformable_transformer import DeformableTransformer from models.segmentation import sigmoid_focal_loss from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh from util.misc import NestedTensor, accuracy __all__ = ["DeformableDetr"] class CustomSetCriterion(SetCriterion): def __init__(self, num_classes, matcher, weight_dict, losses, \ focal_alpha=0.25, use_fed_loss=False): super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha) self.use_fed_loss = use_fed_loss if self.use_fed_loss: self.register_buffer( 'fed_loss_weight', load_class_freq(freq_weight=0.5)) def loss_labels(self, outputs, targets, indices, num_boxes, log=True): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ assert 'pred_logits' in outputs src_logits = outputs['pred_logits'] idx = self._get_src_permutation_idx(indices) target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros( [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) target_classes_onehot = target_classes_onehot[:,:,:-1] # B x N x C if self.use_fed_loss: inds = get_fed_loss_inds( gt_classes=target_classes_o, num_sample_cats=50, weight=self.fed_loss_weight, C=target_classes_onehot.shape[2]) loss_ce = sigmoid_focal_loss( src_logits[:, :, inds], target_classes_onehot[:, :, inds], num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] else: loss_ce = sigmoid_focal_loss( src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_ce': loss_ce} if log: # TODO this should probably be a separate loss, not hacked in this one here losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] return losses class MaskedBackbone(nn.Module): """ This is a thin wrapper around D2's backbone to provide padding masking""" def __init__(self, cfg): super().__init__() self.backbone = build_backbone(cfg) backbone_shape = self.backbone.output_shape() self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] self.strides = [backbone_shape[f].stride for f in backbone_shape.keys()] self.num_channels = [backbone_shape[x].channels for x in backbone_shape.keys()] def forward(self, tensor_list: NestedTensor): xs = self.backbone(tensor_list.tensors) out = {} for name, x in xs.items(): m = tensor_list.mask assert m is not None mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] out[name] = NestedTensor(x, mask) return out @META_ARCH_REGISTRY.register() class DeformableDetr(nn.Module): """ Implement Deformable Detr """ def __init__(self, cfg): super().__init__() self.with_image_labels = cfg.WITH_IMAGE_LABELS self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT self.device = torch.device(cfg.MODEL.DEVICE) self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE self.num_classes = cfg.MODEL.DETR.NUM_CLASSES self.mask_on = cfg.MODEL.MASK_ON hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES # Transformer parameters: nheads = cfg.MODEL.DETR.NHEADS dropout = cfg.MODEL.DETR.DROPOUT dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD enc_layers = cfg.MODEL.DETR.ENC_LAYERS dec_layers = cfg.MODEL.DETR.DEC_LAYERS num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS two_stage = cfg.MODEL.DETR.TWO_STAGE with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE # Loss parameters: giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT l1_weight = cfg.MODEL.DETR.L1_WEIGHT deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION cls_weight = cfg.MODEL.DETR.CLS_WEIGHT focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA N_steps = hidden_dim // 2 d2_backbone = MaskedBackbone(cfg) backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True)) transformer = DeformableTransformer( d_model=hidden_dim, nhead=nheads, num_encoder_layers=enc_layers, num_decoder_layers=dec_layers, dim_feedforward=dim_feedforward, dropout=dropout, activation="relu", return_intermediate_dec=True, num_feature_levels=num_feature_levels, dec_n_points=4, enc_n_points=4, two_stage=two_stage, two_stage_num_proposals=num_queries) self.detr = DeformableDETR( backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, num_feature_levels=num_feature_levels, aux_loss=deep_supervision, with_box_refine=with_box_refine, two_stage=two_stage, ) if self.mask_on: assert 0, 'Mask is not supported yet :(' matcher = HungarianMatcher( cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight) weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight} weight_dict["loss_giou"] = giou_weight if deep_supervision: aux_weight_dict = {} for i in range(dec_layers - 1): aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) print('weight_dict', weight_dict) losses = ["labels", "boxes", "cardinality"] if self.mask_on: losses += ["masks"] self.criterion = CustomSetCriterion( self.num_classes, matcher=matcher, weight_dict=weight_dict, focal_alpha=focal_alpha, losses=losses, use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS ) pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) self.normalizer = lambda x: (x - pixel_mean) / pixel_std def forward(self, batched_inputs): """ Args: Returns: dict[str: Tensor]: mapping from a named loss to a tensor storing the loss. Used during training only. """ images = self.preprocess_image(batched_inputs) output = self.detr(images) if self.training: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] targets = self.prepare_targets(gt_instances) loss_dict = self.criterion(output, targets) weight_dict = self.criterion.weight_dict for k in loss_dict.keys(): if k in weight_dict: loss_dict[k] *= weight_dict[k] if self.with_image_labels: if batched_inputs[0]['ann_type'] in ['image', 'captiontag']: loss_dict['loss_image'] = self.weak_weight * self._weak_loss( output, batched_inputs) else: loss_dict['loss_image'] = images[0].new_zeros( [1], dtype=torch.float32)[0] # import pdb; pdb.set_trace() return loss_dict else: image_sizes = output["pred_boxes"].new_tensor( [(t["height"], t["width"]) for t in batched_inputs]) results = self.post_process(output, image_sizes) return results def prepare_targets(self, targets): new_targets = [] for targets_per_image in targets: h, w = targets_per_image.image_size image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) gt_classes = targets_per_image.gt_classes gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy gt_boxes = box_xyxy_to_cxcywh(gt_boxes) new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) if self.mask_on and hasattr(targets_per_image, 'gt_masks'): assert 0, 'Mask is not supported yet :(' gt_masks = targets_per_image.gt_masks gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) new_targets[-1].update({'masks': gt_masks}) return new_targets def post_process(self, outputs, target_sizes): """ """ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] assert len(out_logits) == len(target_sizes) assert target_sizes.shape[1] == 2 prob = out_logits.sigmoid() topk_values, topk_indexes = torch.topk( prob.view(out_logits.shape[0], -1), self.test_topk, dim=1) scores = topk_values topk_boxes = topk_indexes // out_logits.shape[2] labels = topk_indexes % out_logits.shape[2] boxes = box_cxcywh_to_xyxy(out_bbox) boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) # and from relative [0, 1] to absolute [0, height] coordinates img_h, img_w = target_sizes.unbind(1) scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) boxes = boxes * scale_fct[:, None, :] results = [] for s, l, b, size in zip(scores, labels, boxes, target_sizes): r = Instances((size[0], size[1])) r.pred_boxes = Boxes(b) r.scores = s r.pred_classes = l results.append({'instances': r}) return results def preprocess_image(self, batched_inputs): """ Normalize, pad and batch the input images. """ images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] return images def _weak_loss(self, outputs, batched_inputs): loss = 0 for b, x in enumerate(batched_inputs): labels = x['pos_category_ids'] pred_logits = [outputs['pred_logits'][b]] pred_boxes = [outputs['pred_boxes'][b]] for xx in outputs['aux_outputs']: pred_logits.append(xx['pred_logits'][b]) pred_boxes.append(xx['pred_boxes'][b]) pred_logits = torch.stack(pred_logits, dim=0) # L x N x C pred_boxes = torch.stack(pred_boxes, dim=0) # L x N x 4 for label in labels: loss += self._max_size_loss( pred_logits, pred_boxes, label) / len(labels) loss = loss / len(batched_inputs) return loss def _max_size_loss(self, logits, boxes, label): ''' Inputs: logits: L x N x C boxes: L x N x 4 ''' target = logits.new_zeros((logits.shape[0], logits.shape[2])) target[:, label] = 1. sizes = boxes[..., 2] * boxes[..., 3] # L x N ind = sizes.argmax(dim=1) # L loss = F.binary_cross_entropy_with_logits( logits[range(len(ind)), ind], target, reduction='sum') return loss