import pdb
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from model.transformer_encoder import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding
from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx


def init_weights(module):
    """Initialise ``module`` in place.

    ``nn.Linear`` / ``nn.Embedding`` weights are drawn from N(0, 0.02);
    ``nn.LayerNorm`` gets unit weight and zero bias; a ``nn.Linear`` bias
    (when present) is zeroed. Other module types are left untouched.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


def mask_logits(inputs, mask, mask_value=-1e30):
    """Add ``mask_value`` to ``inputs`` wherever ``mask`` is 0.

    ``mask`` is cast to float32, so positions with mask == 1 are returned
    unchanged and positions with mask == 0 are shifted by ``mask_value``.
    """
    mask = mask.type(torch.float32)
    return inputs + (1.0 - mask) * mask_value


def sim_matrix(a, b, eps=1e-8):
    """Return the pairwise cosine-similarity matrix between rows of ``a`` and ``b``.

    Row norms are clamped from below by ``eps`` for numerical stability.
    Both inputs are normalised along dim 1 and multiplied with ``torch.mm``,
    so they are expected to be 2-D with the same number of columns.
    """
    a_n = a.norm(dim=1)[:, None].clamp(min=eps)
    b_n = b.norm(dim=1)[:, None].clamp(min=eps)
    return torch.mm(a / a_n, (b / b_n).transpose(0, 1))


class WeightedPool(nn.Module):
    """Attention-style pooling over the sequence axis with one learned vector."""

    def __init__(self, dim):
        super(WeightedPool, self).__init__()
        weight = torch.empty(dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, mask):
        """Pool ``x`` over dim 1 using softmax weights restricted by ``mask``.

        Original shape notes: ``alpha`` is (batch_size, seq_length, 1) and the
        pooled result before squeezing is (batch_size, dim, 1).
        """
        alpha = torch.tensordot(x, self.weight, dims=1)  # (batch_size, seq_length, 1)
        alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
        alphas = F.softmax(alpha, dim=1)
        pooled_x = torch.matmul(x.transpose(1, 2), alphas)  # (batch_size, dim, 1)
        return pooled_x.squeeze(2)


def _build_input_proj(in_dim, hidden_dim, dropout, n_input_proj):
    """Build the stack of ``LinearLayer`` blocks used to project raw features.

    Three layers are always constructed and then truncated to
    ``n_input_proj``; this mirrors the original construction order so the
    random-number stream consumed during initialisation is unchanged.
    The last retained layer has its ReLU disabled.
    """
    relu_args = [True] * 3
    relu_args[n_input_proj - 1] = False
    layers = [
        LinearLayer(in_dim, hidden_dim, layer_norm=True, dropout=dropout, relu=relu_args[0]),
        LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=dropout, relu=relu_args[1]),
        LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=dropout, relu=relu_args[2]),
    ]
    return nn.Sequential(*layers[:n_input_proj])


class Model(nn.Module):
    """ This is the UniVTG module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 input_dropout, aux_loss=False,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            input_dropout: float, dropout used inside the input projection layers
            aux_loss: accepted for interface compatibility; not used in this class
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
                Only "l1" is implemented by ``forward``.
            use_txt_pos: bool, whether ``txt_position_embed`` is applied to the text tokens
            n_input_proj: int, number of projection layers kept for each modality
        """
        super().__init__()
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        # 10 learnable prompt tokens that are prepended to every text query.
        self.prompt_learner = nn.Embedding(10, hidden_dim)

        # Index 0 is added to text (and cls) tokens, index 1 to video tokens.
        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)

        # Conv projector
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        self.input_txt_proj = _build_input_proj(txt_dim, hidden_dim, input_dropout, n_input_proj)
        self.input_vid_proj = _build_input_proj(vid_dim, hidden_dim, input_dropout, n_input_proj)

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim)

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
        """Run the model on a batch of text queries and video features.

        Shape notes are carried over from the original inline comments:
        the fused sequence is (bsz, L_vid+L_txt, d) and the video memory is
        (bsz, L_vid, d).

        Returns:
            dict with keys ``pred_logits``, ``pred_spans``, ``src_vid_mask``,
            ``vid_mem_proj``, ``txt_mem_proj``, ``saliency_scores`` and, when
            ``src_cls`` is given, ``cls_mem_proj``.

        Raises:
            NotImplementedError: if ``span_loss_type`` is not "l1".
        """
        bs = src_vid.shape[0]
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        if src_cls is not None:
            src_cls = self.input_txt_proj(src_cls)

        # Prepend the learned prompt tokens (always unmasked) to the text query.
        # The mask is created on the device of the mask it is concatenated with
        # instead of a hard-coded .cuda(), so CPU / non-default devices work.
        src_prompt = self.prompt_learner.weight.unsqueeze(0).repeat(bs, 1, 1)
        src_prompt_mask = torch.ones((bs, src_prompt.shape[1]), device=src_txt_mask.device)
        src_txt = torch.cat([src_prompt, src_txt], dim=1)
        src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)

        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        if src_cls is not None:
            src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        # The transformer receives the inverted mask (True where a position is masked out).
        memory = self.transformer(src, ~mask, pos)
        vid_mem = memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d)

        outputs_class = self.class_embed(vid_mem).sigmoid()
        outputs_coord = self.span_embed(vid_mem)

        if self.span_loss_type != "l1":
            raise NotImplementedError

        # Two sigmoid outputs per clip; the first is negated and the second kept
        # positive. Broadcasting over the last dim replaces the original
        # repeat()-ed (-1, 1) mask that was pinned to CUDA.
        outputs_coord = outputs_coord.sigmoid()
        sign = outputs_coord.new_tensor([-1.0, 1.0])
        outputs_coord = outputs_coord * sign

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}

        # Saliency uses the pre-transformer (projected + type-embedded) video tokens.
        vid_mem_proj = src_vid

        # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        # log(mask + tiny) pushes masked clips towards a very negative score.
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj
        if src_cls is not None:
            out["cls_mem_proj"] = self.weightedpool(src_cls, src_cls_mask)
        out["saliency_scores"] = sim
        return out


class SetCriterion(nn.Module):
    """ Computes the training losses for the model outputs.

    The loss functions visible in this class index ``targets`` directly as a
    mapping of batched tensors; ``matcher`` is stored but not consulted by them
    and ``indices`` is always passed as None from ``forward``.
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type,
                 max_v_l, saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss (currently overridden, see note below)
            span_loss_type: str, [l1, ce]
            max_v_l: int,
            saliency_margin: float
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin
        # NOTE(review): the ``temperature`` argument is accepted but the value
        # actually used is fixed to 0.07, exactly as in the original code.
        # Confirm whether the argument should take effect before changing this.
        self.temperature = 0.07

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)

    def loss_spans(self, outputs, targets, indices):
        """Span regression losses.

        Predicted offsets are added to ``targets['timestamp']`` and compared
        with ``targets['span_labels_nn']`` on positions selected by
        ``targets['timestamp_window']``.

        Returns:
            dict with ``loss_b`` (smooth-L1, summed and divided by the number
            of selected positions) and ``loss_g`` (mean of 1 - generalized
            temporal IoU over the selected positions).
        """
        assert 'pred_spans' in outputs
        src_spans = targets['timestamp'] + outputs['pred_spans']
        gt_spans = targets['span_labels_nn']

        mask_valid = targets['timestamp_window'].bool()
        mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)

        loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
        loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))

        return {
            'loss_b': loss_span.sum() / mask_valid.sum(),
            'loss_g': loss_giou.mean(),
        }

    def loss_labels(self, outputs, targets, indices, log=True):
        """Weighted binary cross-entropy between clip scores and the window mask.

        Positions inside ``targets['timestamp_window']`` are labelled 1 and
        weighted by ``empty_weight[0]``; the remaining positions inside
        ``targets['timestamp_mask']`` are labelled 0 and weighted by
        ``empty_weight[1]``. Positions outside ``timestamp_mask`` contribute
        nothing.
        """
        src_logits = outputs['pred_logits'].squeeze(-1)
        mask = targets['timestamp_mask'].bool()
        mask_valid = targets['timestamp_window'].bool()

        target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[mask_valid] = 1

        # Background weight first, then overwrite the foreground positions.
        weights = torch.zeros_like(target_classes).float()
        weights[mask] = self.empty_weight[1]
        weights[mask_valid] = self.empty_weight[0]

        loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(),
                                         weight=weights, reduction="none") * mask
        return {"loss_f": loss_ce.sum() / mask.sum()}

    def _saliency_inter(self, outputs, targets):
        """Symmetric contrastive loss between each sample's positive clip and the text features.

        Picks, for every batch element, the clip at
        ``targets['saliency_pos_labels'][:, 0]`` and contrasts it against the
        pooled text features of all batch elements (and vice versa).

        Returns:
            (loss, vid_mem_proj, vid_feats, txt_feats, batch_indices, pos_indices)
            so callers can reuse the intermediate tensors.
        """
        vid_mem_proj = outputs["vid_mem_proj"]
        pos_indices = targets["saliency_pos_labels"][:, 0].long()  # (N, #pairs)
        batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)

        vid_feats = vid_mem_proj[batch_indices, pos_indices]
        txt_feats = outputs["txt_mem_proj"].squeeze(1)
        sim = sim_matrix(vid_feats, txt_feats)

        i_logsm = F.log_softmax(sim / self.temperature, dim=1)
        j_logsm = F.log_softmax(sim.t() / self.temperature, dim=1)

        # sum over positives
        idiag = torch.diag(i_logsm)
        jdiag = torch.diag(j_logsm)
        loss_i = idiag.sum() / len(idiag)
        loss_j = jdiag.sum() / len(jdiag)

        loss_inter = - loss_i - loss_j
        return loss_inter, vid_mem_proj, vid_feats, txt_feats, batch_indices, pos_indices

    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        (loss_saliency_inter, vid_mem_proj, _vid_feats, txt_feats,
         batch_indices, pos_indices) = self._saliency_inter(outputs, targets)

        # * intra-vid mode
        mask = targets['timestamp_mask']
        selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
        # Keep clips scored strictly lower than the positive, plus the positive itself,
        # restricted to positions inside timestamp_mask.
        neg_indices_in = (saliency_scores < selected_scores)
        neg_indices_in[batch_indices, pos_indices] = True
        keep_mask = neg_indices_in * mask.bool()

        sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
        # log(keep + tiny) drives excluded positions to a very negative logit.
        sim_in = sim_in + (keep_mask + 1e-45).log()
        logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
        logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)

        pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
        pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
        loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
        loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)

        loss_saliency_intra = - loss_in_i - loss_in_j
        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def loss_saliency_cls(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}
        saliency_scores = targets["saliency_scores"]
        if saliency_scores.sum() == 0:
            return {"loss_s_inter": 0., "loss_s_intra": 0.}

        # * inter-vid mode
        loss_saliency_inter, _, vid_feats, _, _, _ = self._saliency_inter(outputs, targets)

        # * intra-vid mode
        if 'cls_idx' not in targets:  # eval
            return {"loss_s_inter": loss_saliency_inter}

        # Contrast each positive clip against the pooled class features and
        # average the log-probabilities selected by the boolean cls_idx mask.
        cls_indices = targets['cls_idx'].bool()
        cls_feats = outputs["cls_mem_proj"].squeeze(1)
        sim_cls = sim_matrix(vid_feats, cls_feats)

        i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
        idiag_cls = i_logsm_cls[cls_indices]
        loss_cls_i = idiag_cls.sum() / len(idiag_cls)

        loss_saliency_intra = - loss_cls_i
        return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        """Dispatch to the loss function registered under the name ``loss``."""
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "saliency": self.loss_saliency,
            "saliency_cls": self.loss_saliency_cls,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def forward(self, outputs, targets, hl_only=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: mapping of target tensors; the expected keys depend on the
                      losses applied, see each loss' doc
             hl_only: accepted for interface compatibility; not used here
        Returns:
             dict mapping loss names to their (unweighted) values
        """
        indices = None
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))
        return losses


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        """Apply every linear layer in order, with ReLU after all but the last."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Conv(nn.Module):
    """ Stack of 1-D convolutions applied along the sequence axis.

    Mirrors ``MLP`` but uses ``nn.Conv1d`` layers ("same"-style zero padding of
    ``kernel_size // 2``, stride 1) instead of linear layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
            for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        """Permute dims (0, 2, 1) for Conv1d, run the stack, and permute back.

        ReLU follows every layer except the last.
        """
        x = x.permute(0, 2, 1)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.permute(0, 2, 1)


class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            # Normalisation is applied to the *input* features, before the linear map.
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


def build_model(args):
    """Construct the model and its criterion from a parsed-arguments object.

    The set of losses depends on ``args.dset_type``:
      * 'mr': spans + labels + saliency ('saliency_cls' when 'tal' occurs in
        ``args.train_path``)
      * 'hl' / 'vs': labels + saliency

    Only the criterion is moved to ``args.device``; the model is returned as
    constructed.

    Returns:
        (model, criterion)

    Raises:
        ValueError: if ``args.dset_type`` is not one of 'mr', 'hl', 'vs'.
    """
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}

    if args.dset_type in ['mr']:
        if 'tal' not in args.train_path:
            losses = ['spans', 'labels', 'saliency']
        else:
            losses = ['spans', 'labels', 'saliency_cls']
    elif args.dset_type in ['hl', 'vs']:
        losses = ['labels', 'saliency']
    else:
        # Previously fell through and failed later with an unbound ``losses``.
        raise ValueError(f"unsupported dset_type: {args.dset_type!r}")

    criterion = SetCriterion(
        matcher=matcher,
        weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin,
    )
    criterion.to(device)
    return model, criterion