import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.model_utils import RNNEncoder
from easydict import EasyDict as edict


cal_base_cfg = edict(
    visual_input_size=2048,
    textual_input_size=768,
    query_feat_size=768,
    visual_hidden_size=500,
    output_size=100,
    embedding_size=768,
    lstm_hidden_size=1000,
    margin=0.1,
    loss_type="hinge",
    inter_loss_weight=0.4,
    ctx_mode="video"
)
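
# Variant configs can be derived by overriding fields of the base config; a
# minimal sketch (the "video_sub" value is an assumption based on the ctx_mode
# substring checks in CALWithSub below):
#   cfg = edict(dict(cal_base_cfg, ctx_mode="video_sub"))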


class CAL(nn.Module):
    def __init__(self, config):
        super(CAL, self).__init__()
        self.config = config

        # MLP that projects clip-level visual features into the joint embedding space.
        self.moment_mlp = nn.Sequential(
            nn.Linear(config.visual_input_size, config.visual_hidden_size),
            nn.ReLU(True),
            nn.Linear(config.visual_hidden_size, config.output_size),
        )

        # Single-layer unidirectional LSTM over the query word features; only the
        # final hidden state is kept (return_outputs=False).
        self.query_lstm = RNNEncoder(word_embedding_size=config.embedding_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat):
        """moment_feat: (N, L_clip, D_v)"""
        return F.normalize(self.moment_mlp(moment_feat), p=2, dim=-1)

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, where 1 indicates a valid token, 0 indicates padding
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def compute_pdist(self, query_embedding, moment_feat, moment_mask):
        """Pairwise L2 distance between row-aligned (query, moment) pairs.
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat)  # (N, L_clip, D_o)
        # Squared L2 distance from each clip embedding to its paired query embedding.
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)
        # Masked mean over valid clips only.
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)
        return moment_dist
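
    # Illustrative shapes (an added example, not from the original code): with
    # N=2 moments of L_clip=3 clips where the second moment has one padded clip,
    # moment_mask = [[1, 1, 1], [1, 1, 0]], so the two entries of moment_dist
    # average over 3 and 2 valid clip distances respectively.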

    @classmethod
    def compute_cdist_inference(cls, query_embeddings, moment_embeddings, moment_mask):
        """Compute the L2 distance for every possible (query, proposal) pair. This differs
        from compute_pdist, which only computes distances for row-aligned pairs.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_dist: (N_q, N_prop)
        """
        # Move precomputed proposal embeddings to the query device if needed.
        query_device = query_embeddings.device
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        # torch.cdist returns L2 distances; square them to match compute_pdist.
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
        return query_moment_dist
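
    # Consistency note (illustrative, not from the original code): when N_q ==
    # N_prop and the rows are aligned, the diagonal of compute_cdist_inference
    # equals compute_pdist applied to the same pairs, since both take the masked
    # mean of squared L2 clip distances.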

    def forward(self, query_feat, query_mask, pos_moment_feat, pos_moment_mask,
                intra_neg_moment_feat, intra_neg_moment_mask,
                inter_neg_moment_feat, inter_neg_moment_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_feat: (N, L_clip_1, D_v)
            pos_moment_mask: (N, L_clip_1)
            intra_neg_moment_feat: (N, L_clip_2, D_v)
            intra_neg_moment_mask: (N, L_clip_2)
            inter_neg_moment_feat: (N, L_clip_3, D_v)
            inter_neg_moment_mask: (N, L_clip_3)
        """
        query_embed = self.query_encoder(query_feat, query_mask)
        pos_dist = self.compute_pdist(query_embed, pos_moment_feat, pos_moment_mask)
        intra_neg_dist = self.compute_pdist(query_embed, intra_neg_moment_feat, intra_neg_moment_mask)
        if self.config.inter_loss_weight == 0:
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(query_embed, inter_neg_moment_feat, inter_neg_moment_mask)
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss that encourages the positive distance to be smaller than the negative one.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":
            # max(0, margin + d_pos - d_neg), averaged over the batch.
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":
            # Smooth ranking loss: log(1 + exp(d_pos - d_neg)).
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
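

# A minimal smoke test for CAL (a hypothetical sketch, not part of the original
# repo): it builds the model from cal_base_cfg and runs one forward pass on
# random features, reusing the same tensors for positives and negatives purely
# to check shapes and confirm that a scalar loss comes out.
def _demo_cal():
    model = CAL(cal_base_cfg)
    n, l_q, l_clip = 4, 8, 6
    query_feat = torch.randn(n, l_q, cal_base_cfg.embedding_size)
    query_mask = torch.ones(n, l_q)
    moment_feat = torch.randn(n, l_clip, cal_base_cfg.visual_input_size)
    moment_mask = torch.ones(n, l_clip)
    loss = model(query_feat, query_mask, moment_feat, moment_mask,
                 moment_feat, moment_mask, moment_feat, moment_mask)
    assert loss.dim() == 0  # a scalar training loss
    return loss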


class CALWithSub(nn.Module):
    def __init__(self, config):
        super(CALWithSub, self).__init__()
        self.config = config
        # ctx_mode is a composite string, e.g. "video", "sub", "tef", or combinations.
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        self.use_tef = "tef" in config.ctx_mode
        self.tef_only = self.use_tef and not self.use_video and not self.use_sub
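
        # Illustrative flag combinations (an added note, not from the original code):
        #   ctx_mode="video_sub" -> use_video=True,  use_sub=True,  tef_only=False
        #   ctx_mode="tef"       -> use_video=False, use_sub=False, tef_only=True;
        #     in the tef-only case the "video" MLP below handles the TEF features,
        #     presumably with visual_input_size set to the TEF dimension elsewhere.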

        if self.use_video or self.tef_only:
            self.video_moment_mlp = nn.Sequential(
                nn.Linear(config.visual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        if self.use_sub:
            self.sub_moment_mlp = nn.Sequential(
                nn.Linear(config.textual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        self.query_lstm = RNNEncoder(word_embedding_size=config.query_feat_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat, module_name="video"):
        """moment_feat: (N, L_clip, D_v)"""
        if moment_feat is not None:
            # Dispatch to video_moment_mlp or sub_moment_mlp by name.
            encoder = getattr(self, module_name + "_moment_mlp")
            return F.normalize(encoder(moment_feat), p=2, dim=-1)
        else:
            return None

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, where 1 indicates a valid token, 0 indicates padding
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def _compute_pdist(self, query_embedding, moment_feat, moment_mask, module_name="video"):
        """Pairwise L2 distance between row-aligned (query, moment) pairs for a single modality.
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat, module_name=module_name)
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)
        # Masked mean over valid clips only.
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)
        return moment_dist

    def compute_pdist(self, query_embedding, moment_video_feat, moment_sub_feat, moment_mask):
        """Pairwise L2 distance, averaged over the active modalities.
        Args:
            query_embedding: (N, D_o)
            moment_video_feat: (N, L_clip, D_v)
            moment_sub_feat: (N, L_clip, D_t)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        # Bool arithmetic: divisor is the number of active modality branches (1 or 2).
        divisor = (self.use_video or self.tef_only) + self.use_sub
        video_moment_dist = self._compute_pdist(query_embedding, moment_video_feat, moment_mask, module_name="video") \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_pdist(query_embedding, moment_sub_feat, moment_mask, module_name="sub") \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor

    def _compute_cdist_inference(self, query_embeddings, moment_embeddings, moment_mask):
        """Compute the L2 distance for every possible (query, proposal) pair. This differs
        from compute_pdist, which only computes distances for row-aligned pairs.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_dist: (N_q, N_prop)
        """
        # Move precomputed proposal embeddings to the query device if needed.
        query_device = query_embeddings.device
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
        return query_moment_dist

    def compute_cdist_inference(self, query_embeddings, video_moment_embeddings, sub_moment_embeddings, moment_mask):
        """Cross distances as in _compute_cdist_inference, averaged over the active modalities."""
        divisor = (self.use_video or self.tef_only) + self.use_sub
        video_moment_dist = self._compute_cdist_inference(query_embeddings, video_moment_embeddings, moment_mask) \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_cdist_inference(query_embeddings, sub_moment_embeddings, moment_mask) \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor
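
    # Typical inference flow (an assumption about the surrounding evaluation code,
    # not shown here): proposal embeddings for both modalities are precomputed once
    # with moment_encoder(...), then scored against each query batch via
    # compute_cdist_inference, avoiding repeated MLP passes over the proposals.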

    def forward(self, query_feat, query_mask, pos_moment_video_feat, pos_moment_video_mask,
                intra_neg_moment_video_feat, intra_neg_moment_video_mask,
                inter_neg_moment_video_feat, inter_neg_moment_video_mask,
                pos_moment_sub_feat, pos_moment_sub_mask,
                intra_neg_moment_sub_feat, intra_neg_moment_sub_mask,
                inter_neg_moment_sub_feat, inter_neg_moment_sub_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_video_feat: (N, L_clip_1, D_v)
            pos_moment_video_mask: (N, L_clip_1)
            intra_neg_moment_video_feat: (N, L_clip_2, D_v)
            intra_neg_moment_video_mask: (N, L_clip_2)
            inter_neg_moment_video_feat: (N, L_clip_3, D_v)
            inter_neg_moment_video_mask: (N, L_clip_3)
            pos_moment_sub_feat: (N, L_clip_1, D_t)
            pos_moment_sub_mask: (N, L_clip_1)
            intra_neg_moment_sub_feat: (N, L_clip_2, D_t)
            intra_neg_moment_sub_mask: (N, L_clip_2)
            inter_neg_moment_sub_feat: (N, L_clip_3, D_t)
            inter_neg_moment_sub_mask: (N, L_clip_3)
        """
        query_embed = self.query_encoder(query_feat, query_mask)
        # compute_pdist applies a single mask to both modalities: the sub mask is
        # used when subtitles are active, otherwise the video mask.
        pos_dist = self.compute_pdist(
            query_embed, pos_moment_video_feat, pos_moment_sub_feat,
            moment_mask=pos_moment_sub_mask if self.use_sub else pos_moment_video_mask)
        intra_neg_dist = self.compute_pdist(
            query_embed, intra_neg_moment_video_feat, intra_neg_moment_sub_feat,
            moment_mask=intra_neg_moment_sub_mask if self.use_sub else intra_neg_moment_video_mask)
        if self.config.inter_loss_weight == 0:
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(
                query_embed, inter_neg_moment_video_feat, inter_neg_moment_sub_feat,
                moment_mask=inter_neg_moment_sub_mask if self.use_sub else inter_neg_moment_video_mask)
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss that encourages the positive distance to be smaller than the negative one.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":
            # max(0, margin + d_pos - d_neg), averaged over the batch.
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":
            # Smooth ranking loss: log(1 + exp(d_pos - d_neg)).
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
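

# A hypothetical end-to-end smoke test (not part of the original repo): derive a
# "video_sub" config from the base config, then run one forward pass of
# CALWithSub on random features to confirm a scalar loss. Positive and negative
# moments reuse the same tensors purely for the shape check.
if __name__ == "__main__":
    cfg = edict(dict(cal_base_cfg, ctx_mode="video_sub"))
    model = CALWithSub(cfg)
    n, l_q, l_clip = 4, 8, 6
    query_feat = torch.randn(n, l_q, cfg.query_feat_size)
    query_mask = torch.ones(n, l_q)
    video_feat = torch.randn(n, l_clip, cfg.visual_input_size)
    sub_feat = torch.randn(n, l_clip, cfg.textual_input_size)
    mask = torch.ones(n, l_clip)
    loss = model(query_feat, query_mask,
                 video_feat, mask, video_feat, mask, video_feat, mask,
                 sub_feat, mask, sub_feat, mask, sub_feat, mask)
    print("smoke-test loss:", loss.item())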