# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn

from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

from model.transformer import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Parameters:
        output: (#items, #classes) tensor of scores.
        target: int, the class index every item is compared against.
        topk: iterable of k values.

    Returns:
        A list with one scalar tensor per entry of `topk`, holding the percentage of
        items whose top-k predictions contain `target`. When `output` has no items,
        zeros are returned instead of dividing by zero.
    """
    maxk = max(topk)
    num_items = output.size(0)
    if num_items == 0:
        # No matched predictions (e.g. a batch without any target span).
        return [torch.zeros([], device=output.device) for _ in topk]

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, #items)
    correct = pred.eq(target)

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be non-contiguous,
        # so use reshape (view can fail for k > 1).
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / num_items))
    return res


class Model(nn.Module):
    """ This is the Moment-DETR module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 num_queries, input_dropout, aux_loss=False,
                 contrastive_align_loss=False, contrastive_hdim=64,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         Moment-DETR can detect in a single video.
            input_dropout: float, dropout used inside the input projection layers
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            contrastive_align_loss: If true, perform span - tokens contrastive learning
            contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            use_txt_pos: bool, if True `txt_position_embed` is applied to the text tokens,
                otherwise zeros are used as text positions
            n_input_proj: int, number of LinearLayer modules kept in each input projection stack;
                the last kept layer has no ReLU
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
        self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
        # Label convention follows SetCriterion: 0: foreground, 1: background
        self.class_embed = nn.Linear(hidden_dim, 2)
        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        # self.foreground_thd = foreground_thd
        # self.background_thd = background_thd
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # ReLU after every projection layer except the last one that is kept.
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)

        self.saliency_proj = nn.Linear(hidden_dim, 1)
        self.aux_loss = aux_loss

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask):
        """The forward expects four tensors:
               - src_txt: [batch_size, L_txt, D_txt]
               - src_txt_mask: [batch_size, L_txt], containing 0 on padded positions,
                    will convert to 1 as padding later for transformer
               - src_vid: [batch_size, L_vid, D_vid]
               - src_vid_mask: [batch_size, L_vid], containing 0 on padded positions,
                    will convert to 1 as padding later for transformer

            It returns a dict with the following elements:
               - "pred_logits": the 2-way classification logits of the last decoder layer for all queries.
               - "pred_spans": the span predictions of the last decoder layer for all queries.
                               With span_loss_type == "l1" these are passed through a sigmoid, i.e.
                               normalized to [0, 1], and represented as (center_x, width);
                               otherwise they are raw scores of size max_v_l * 2.
               - "saliency_scores": [batch_size, L_vid] scores computed from the video part of the memory.
               - "proj_queries", "proj_txt_mem", "proj_vid_mem": only when contrastive_align_loss is set.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing "pred_logits" and "pred_spans" for each
                                intermediate decoder layer.
        """
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)
        # TODO should we remove or use different positional embeddings to the src_txt?
        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        # pos_txt = torch.zeros_like(src_txt)
        # pad zeros for txt positions
        pos = torch.cat([pos_vid, pos_txt], dim=1)
        # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d)
        # `~mask` flips the convention so that padded positions are True for the transformer.
        hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos)
        outputs_class = self.class_embed(hs)  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(hs)  # (#layers, bsz, #queries, 2 or max_v_l * 2)
        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}

        # memory is laid out as [video tokens, text tokens] along dim 1 (same order as `src`).
        txt_mem = memory[:, src_vid.shape[1]:]  # (bsz, L_txt, d)
        vid_mem = memory[:, :src_vid.shape[1]]  # (bsz, L_vid, d)
        if self.contrastive_align_loss:
            proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
            proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
            proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
            out.update(dict(
                proj_queries=proj_queries[-1],
                proj_txt_mem=proj_txt_mem,
                proj_vid_mem=proj_vid_mem
            ))

        out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1)  # (bsz, L_vid)

        if self.aux_loss:
            # assert proj_queries and proj_txt_mem
            out['aux_outputs'] = [
                {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
            if self.contrastive_align_loss:
                assert proj_queries is not None
                for idx, d in enumerate(proj_queries[:-1]):
                    out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
        return out

    # @torch.jit.unused
    # def _set_aux_loss(self, outputs_class, outputs_coord):
    #     # this is a workaround to make torchscript happy, as torchscript
    #     # doesn't support dictionary with non-homogeneous values, such
    #     # as a dict having both a Tensor and a list.
    #     return [{'pred_logits': a, 'pred_spans': b}
    #             for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        """Compute the losses related to the spans: the span regression/classification loss ("loss_b")
           and the generalized temporal IoU loss ("loss_g").
           targets must contain the key "span_labels", a per-sample sequence of dicts each holding
           the key "spans" with a tensor of dim [nb_tgt_spans, 2].
           For span_loss_type == "l1" the target spans are expected in format (center_x, w).
           For "ce" no GIoU is computed and "loss_g" is zero.
        """
        assert 'pred_spans' in outputs
        targets = targets["span_labels"]
        idx = self._get_src_permutation_idx(indices)
        src_spans = outputs['pred_spans'][idx]  # (#spans, max_v_l * 2)
        tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0)  # (#spans, 2)
        if self.span_loss_type == "l1":
            loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
            loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
        else:  # ce
            n_spans = src_spans.shape[0]
            src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
            loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')

            # giou
            # src_span_indices = src_spans.max(1)[1]  # (#spans, 2)
            # src_span_indices[:, 1] += 1  # ed non-inclusive [st, ed)
            #
            # tgt_span_indices = tgt_spans
            # tgt_span_indices[:, 1] += 1
            # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices))
            loss_giou = loss_span.new_zeros([1])

        losses = {}
        losses['loss_b'] = loss_span.mean()
        losses['loss_g'] = loss_giou.mean()
        return losses

    def loss_labels(self, outputs, targets, indices, log=True):
        """Classification loss (weighted cross-entropy) between foreground and background queries.
        Matched queries are labeled foreground, all remaining queries background.
        When `log` is True, "class_error" (100 - accuracy on matched queries) is also returned.
        """
        # TODO add foreground and background classifier.  use all non-matched as background.
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # (batch_size, #queries, #classes=2)
        # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
        idx = self._get_src_permutation_idx(indices)
        target_classes = torch.full(src_logits.shape[:2], self.background_label,
                                    dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[idx] = self.foreground_label

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
        losses = {'loss_f': loss_ce.mean()}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
        return losses

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_intra": 0}
        saliency_scores = outputs["saliency_scores"]  # (N, L)
        pos_indices = targets["saliency_pos_labels"]  # (N, #pairs)
        neg_indices = targets["saliency_neg_labels"]  # (N, #pairs)
        num_pairs = pos_indices.shape[1]  # typically 2 or 4
        batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
        pos_scores = torch.stack(
            [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        neg_scores = torch.stack(
            [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        # Hinge loss: a positive clip should score at least `saliency_margin` above its paired negative.
        loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
            / (len(pos_scores) * num_pairs) * 2  # * 2 to keep the loss the same scale
        return {"loss_s_intra": loss_saliency}

    def loss_contrastive_align(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text"""
        normalized_text_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d)  text tokens
        normalized_img_embed = outputs["proj_queries"]  # (bsz, #queries, d)
        logits = torch.einsum(
            "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed)  # (bsz, #queries, #tokens)
        logits = logits.sum(2) / self.temperature  # (bsz, #queries)
        idx = self._get_src_permutation_idx(indices)
        positive_map = torch.zeros_like(logits, dtype=torch.bool)
        positive_map[idx] = True
        positive_logits = logits.masked_fill(~positive_map, 0)

        pos_term = positive_logits.sum(1)  # (bsz, )
        num_pos = positive_map.sum(1)  # (bsz, )
        neg_term = logits.logsumexp(1)  # (bsz, )
        # NOTE: a sample without any matched query has num_pos == 0, which yields 0 / 0 here.
        loss_nce = - pos_term / num_pos + neg_term  # (bsz, )
        losses = {"loss_contrastive_align": loss_nce.mean()}
        return losses

    def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text

        Currently computes exactly the same loss as `loss_contrastive_align`, so it delegates to it.
        """
        # TODO (1)  align vid_mem and txt_mem;
        # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR
        return self.loss_contrastive_align(outputs, targets, indices, log=log)

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx  # two 1D tensors of the same length

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        """Dispatch to the loss function registered under the name `loss`."""
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "contrastive_align": self.loss_contrastive_align,
            "saliency": self.loss_saliency,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: dict. The loss functions read the key "span_labels" (a per-sample sequence of
                      dicts with key "spans") and, optionally, "saliency_pos_labels" /
                      "saliency_neg_labels". The expected keys depend on the losses applied,
                      see each loss' doc
        Returns:
             dict mapping loss names to values; losses computed on the i-th auxiliary
             output carry the suffix `_{i}`.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
        indices = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if "saliency" == loss:  # skip as it is only in the top layer
                        continue
                    kwargs = {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Stack `num_layers` linear layers: input_dim -> hidden_dim -> ... -> output_dim."""
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        """Apply every layer, with ReLU after all but the last one."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        """
        Parameters:
            in_hsz: int, input feature size
            out_hsz: int, output feature size
            layer_norm: bool, apply LayerNorm on the input (over in_hsz) before the linear layer
            dropout: float, dropout probability applied before the linear layer
            relu: bool, apply ReLU on the output
        """
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    """Build the Moment-DETR model and its criterion from the parsed `args`.

    Returns:
        (model, criterion). Only the criterion is moved to `args.device` here.
    """
    # NOTE: the comment below is inherited from DETR; no `num_classes` is used in this function.
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        num_queries=args.num_queries,
        input_dropout=args.input_dropout,
        aux_loss=args.aux_loss,
        # contrastive_align_loss=args.contrastive_align_loss,
        # contrastive_hdim=args.contrastive_hdim,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}
    # if args.contrastive_align_loss:
    #     weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        # Saliency losses are only computed on the top layer (SetCriterion.forward skips
        # them for aux outputs), so they get no per-layer weight.
        saliency_keys = ("loss_s_intra", "loss_s_inter")
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update(
                {k + f'_{i}': v for k, v in weight_dict.items() if k not in saliency_keys})
        weight_dict.update(aux_weight_dict)

    losses = ['spans', 'labels', 'saliency']
    # if args.contrastive_align_loss:
    #     losses += ["contrastive_align"]
    criterion = SetCriterion(
        matcher=matcher, weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin
    )
    criterion.to(device)
    return model, criterion