# Liangrj5
# init
# ebf5d87
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.model_utils import RNNEncoder
from easydict import EasyDict as edict
# Default hyper-parameters for the CAL / CALWithSub models defined in this file.
# Attribute access (cfg.margin, cfg.ctx_mode, ...) is what the models rely on.
cal_base_cfg = edict({
    # Input width of the video moment MLP; changes based on visual input type.
    "visual_input_size": 2048,
    # Input width of the subtitle moment MLP (CALWithSub only).
    "textual_input_size": 768,
    # Per-token query feature width fed to the query LSTM in CALWithSub.
    "query_feat_size": 768,
    # Hidden width of the two-layer moment MLPs (video and subtitle alike).
    "visual_hidden_size": 500,
    # Width of the joint embedding produced by the query and moment encoders.
    "output_size": 100,
    # Per-token query feature width fed to the query LSTM in CAL.
    "embedding_size": 768,
    # Hidden width of the query LSTM.
    "lstm_hidden_size": 1000,
    # Margin for the ranking loss (read by the 'hinge' branch).
    "margin": 0.1,
    # Loss type, 'hinge' or 'lse'.
    "loss_type": "hinge",
    # Weight for inter negatives; 0 skips the inter-negative term entirely.
    "inter_loss_weight": 0.4,
    # CALWithSub checks this string for the substrings "video", "sub", "tef".
    "ctx_mode": "video",
})
class CAL(nn.Module):
    """Query/moment embedding model trained with a distance ranking loss.

    A query is encoded by an LSTM followed by a linear layer, and every clip of
    a moment is encoded by a two-layer MLP. Both embeddings are L2-normalized,
    and a (query, moment) pair is scored by the squared L2 distance between the
    query and each clip, averaged over the valid clips of the moment. Smaller
    distance means a better match.
    """

    def __init__(self, config):
        """Build the encoders.

        Args:
            config: object with attributes visual_input_size, visual_hidden_size,
                output_size, embedding_size, lstm_hidden_size, and (read later)
                margin, loss_type, inter_loss_weight.
        """
        super().__init__()
        self.config = config

        # Clip encoder: D_v -> visual_hidden_size -> D_o
        self.moment_mlp = nn.Sequential(
            nn.Linear(config.visual_input_size, config.visual_hidden_size),
            nn.ReLU(True),
            nn.Linear(config.visual_hidden_size, config.output_size),
        )

        # Query encoder: only the final hidden state is used (return_outputs=False).
        self.query_lstm = RNNEncoder(
            word_embedding_size=config.embedding_size,
            hidden_size=config.lstm_hidden_size,
            bidirectional=False,
            rnn_type="lstm",
            dropout_p=0,
            n_layers=1,
            return_outputs=False,
        )
        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat):
        """Embed every clip and L2-normalize along the feature axis.

        Args:
            moment_feat: (N, L_clip, D_v)
        Returns:
            (N, L_clip, D_o)
        """
        clip_embedding = self.moment_mlp(moment_feat)
        return F.normalize(clip_embedding, p=2, dim=-1)

    def query_encoder(self, query_feat, query_mask):
        """Embed each query into a single L2-normalized vector.

        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
        Returns:
            (N, D_o)
        """
        # Number of valid tokens per query, passed to the RNN encoder as lengths.
        query_lengths = torch.sum(query_mask, dim=1).long()
        _, hidden = self.query_lstm(query_feat, query_lengths)
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def compute_pdist(self, query_embedding, moment_feat, moment_mask):
        """Row-wise squared L2 distance between query i and moment i.

        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        Returns:
            (N, ) mean distance over the valid clips of each moment. A row whose
            mask is all zeros divides by zero and yields NaN.
        """
        moment_embedding = self.moment_encoder(moment_feat)  # (N, L_clip, D_o)
        diff = moment_embedding - query_embedding.unsqueeze(1)  # (N, L_clip, D_o)
        moment_clip_dist = torch.sum(diff ** 2, dim=2)  # (N, L_clip)
        # Masked mean over clips: padded clips contribute nothing to the sum.
        masked_sum = torch.sum(moment_clip_dist * moment_mask, dim=1)
        return masked_sum / moment_mask.sum(1)  # (N, )

    @classmethod
    def compute_cdist_inference(cls, query_embeddings, moment_embeddings, moment_mask):
        """Compute L2 distance for every possible pair of queries and proposals.

        This is different from compute_pdist as the latter computes only pairs
        at each row.

        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_scores: (N_q, N_prop)
        """
        # Sync devices. The mask is moved independently of the embeddings so a
        # mask left on another device is handled too; `.to` is a no-op when the
        # tensor already lives on the target device.
        query_device = query_embeddings.device
        moment_embeddings = moment_embeddings.to(query_device)
        moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape

        # Flatten proposals so one cdist call covers every (query, clip) pair.
        flat_clips = moment_embeddings.reshape(-1, d)  # (N_prop * N_clips, D_o)
        query_clip_dist = torch.cdist(query_embeddings, flat_clips, p=2) ** 2
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)

        # Masked mean over the clips of each proposal.
        masked_sum = torch.sum(query_clip_dist * moment_mask.unsqueeze(0), dim=2)
        return masked_sum / moment_mask.sum(1).unsqueeze(0)  # (N_q, N_prop)

    def forward(self, query_feat, query_mask, pos_moment_feat, pos_moment_mask,
                intra_neg_moment_feat, intra_neg_moment_mask,
                inter_neg_moment_feat, inter_neg_moment_mask):
        """Compute the training loss for a batch.

        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_feat: (N, L_clip_1, D_v)
            pos_moment_mask: (N, L_clip_1)
            intra_neg_moment_feat: (N, L_clip_2, D_v)
            intra_neg_moment_mask: (N, L_clip_2)
            inter_neg_moment_feat: (N, L_clip_3, D_v)
            inter_neg_moment_mask: (N, L_clip_3)
        Returns:
            Scalar loss: intra-negative ranking loss plus
            config.inter_loss_weight times the inter-negative ranking loss.
        """
        query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)
        pos_dist = self.compute_pdist(query_embed, pos_moment_feat, pos_moment_mask)  # (N, )
        intra_neg_dist = self.compute_pdist(
            query_embed, intra_neg_moment_feat, intra_neg_moment_mask)  # (N, )

        if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
            # Skip encoding the inter negatives altogether.
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(
                query_embed, inter_neg_moment_feat, inter_neg_moment_mask)  # (N, )
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss_intra = self.calc_loss(pos_dist, intra_neg_dist)
        return loss_intra + self.config.inter_loss_weight * loss_inter

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss; encourages positive distance to be smaller than negative distance.

        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        Returns:
            Scalar tensor, the per-sample loss averaged over N.
        Raises:
            NotImplementedError: if config.loss_type is neither 'hinge' nor 'lse'.
        """
        loss_type = self.config.loss_type
        if loss_type == "hinge":  # max(0, m + S_pos - S_neg)
            per_sample = torch.clamp(self.config.margin + pos_dist - neg_dist, min=0)
        elif loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
            # softplus(x) == log(1 + exp(x)) but does not overflow to inf for
            # large x the way an explicit exp() does.
            per_sample = F.softplus(pos_dist - neg_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
        return per_sample.sum() / len(pos_dist)
class CALWithSub(nn.Module):
    """Query/moment embedding model with optional video and subtitle streams.

    Which streams are active is decided by substrings of config.ctx_mode:
    "video" enables the video MLP, "sub" enables the subtitle MLP, and "tef"
    alone (no "video", no "sub") also routes through the video MLP. When both
    streams are active, their distances are averaged.
    """

    def __init__(self, config):
        """Build the encoders selected by config.ctx_mode.

        Args:
            config: object with attributes ctx_mode, visual_input_size,
                textual_input_size, visual_hidden_size, output_size,
                query_feat_size, lstm_hidden_size, and (read later) margin,
                loss_type, inter_loss_weight.
        Raises:
            ValueError: if ctx_mode enables no stream at all.
        """
        super().__init__()
        self.config = config
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        self.use_tef = "tef" in config.ctx_mode
        self.tef_only = self.use_tef and not self.use_video and not self.use_sub

        # Fail fast: with no stream enabled the distance average would later
        # divide by zero inside compute_pdist / compute_cdist_inference.
        if not (self.use_video or self.use_sub or self.tef_only):
            raise ValueError(
                "config.ctx_mode must contain at least one of 'video', 'sub' or 'tef', "
                "got {!r}".format(config.ctx_mode))

        if self.use_video or self.tef_only:
            self.video_moment_mlp = nn.Sequential(
                nn.Linear(config.visual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        if self.use_sub:
            self.sub_moment_mlp = nn.Sequential(
                nn.Linear(config.textual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        # Query encoder: only the final hidden state is used (return_outputs=False).
        self.query_lstm = RNNEncoder(
            word_embedding_size=config.query_feat_size,
            hidden_size=config.lstm_hidden_size,
            bidirectional=False,
            rnn_type="lstm",
            dropout_p=0,
            n_layers=1,
            return_outputs=False,
        )
        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat, module_name="video"):
        """Embed every clip with the MLP of the given stream and L2-normalize.

        Args:
            moment_feat: (N, L_clip, D_v), or None
            module_name: "video" or "sub"; selects `<module_name>_moment_mlp`.
        Returns:
            (N, L_clip, D_o), or None when moment_feat is None.
        """
        if moment_feat is None:
            return None
        encoder = getattr(self, module_name + "_moment_mlp")
        return F.normalize(encoder(moment_feat), p=2, dim=-1)

    def query_encoder(self, query_feat, query_mask):
        """Embed each query into a single L2-normalized vector.

        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
        Returns:
            (N, D_o)
        """
        # Number of valid tokens per query, passed to the RNN encoder as lengths.
        query_lengths = torch.sum(query_mask, dim=1).long()
        _, hidden = self.query_lstm(query_feat, query_lengths)
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)

    def _compute_pdist(self, query_embedding, moment_feat, moment_mask, module_name="video"):
        """Row-wise squared L2 distance for a single stream.

        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
            module_name: "video" or "sub"
        Returns:
            (N, ) mean distance over the valid clips of each moment. A row whose
            mask is all zeros divides by zero and yields NaN.
        """
        moment_embedding = self.moment_encoder(moment_feat, module_name=module_name)
        diff = moment_embedding - query_embedding.unsqueeze(1)  # (N, L_clip, D_o)
        moment_clip_dist = torch.sum(diff ** 2, dim=2)  # (N, L_clip)
        # Masked mean over clips: padded clips contribute nothing to the sum.
        masked_sum = torch.sum(moment_clip_dist * moment_mask, dim=1)
        return masked_sum / moment_mask.sum(1)  # (N, )

    def compute_pdist(self, query_embedding, moment_video_feat, moment_sub_feat, moment_mask):
        """Row-wise distance averaged over the active streams.

        Args:
            query_embedding: (N, D_o)
            moment_video_feat: (N, L_clip, D_v)
            moment_sub_feat: (N, L_clip, D_t)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        Returns:
            (N, )
        """
        stream_dists = []
        if self.use_video or self.tef_only:
            stream_dists.append(self._compute_pdist(
                query_embedding, moment_video_feat, moment_mask, module_name="video"))
        if self.use_sub:
            stream_dists.append(self._compute_pdist(
                query_embedding, moment_sub_feat, moment_mask, module_name="sub"))
        return sum(stream_dists) / len(stream_dists)  # (N, )

    def _compute_cdist_inference(self, query_embeddings, moment_embeddings, moment_mask):
        """Compute L2 distance for every possible pair of queries and proposals.

        This is different from compute_pdist as the latter computes only pairs
        at each row.

        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        Returns:
            query_moment_scores: (N_q, N_prop)
        """
        # Sync devices. The mask is moved independently of the embeddings so a
        # mask left on another device is handled too; `.to` is a no-op when the
        # tensor already lives on the target device.
        query_device = query_embeddings.device
        moment_embeddings = moment_embeddings.to(query_device)
        moment_mask = moment_mask.to(query_device)

        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape

        # Flatten proposals so one cdist call covers every (query, clip) pair.
        flat_clips = moment_embeddings.reshape(-1, d)  # (N_prop * N_clips, D_o)
        query_clip_dist = torch.cdist(query_embeddings, flat_clips, p=2) ** 2
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)

        # Masked mean over the clips of each proposal.
        masked_sum = torch.sum(query_clip_dist * moment_mask.unsqueeze(0), dim=2)
        return masked_sum / moment_mask.sum(1).unsqueeze(0)  # (N_q, N_prop)

    def compute_cdist_inference(self, query_embeddings, video_moment_embeddings,
                                sub_moment_embeddings, moment_mask):
        """All-pairs distance averaged over the active streams.

        Args:
            query_embeddings: (N_q, D_o)
            video_moment_embeddings: (N_prop, N_clips, D_o); ignored unless the
                video stream is active.
            sub_moment_embeddings: (N_prop, N_clips, D_o); ignored unless the
                subtitle stream is active.
            moment_mask: (N_prop, N_clips)
        Returns:
            (N_q, N_prop)
        """
        stream_dists = []
        if self.use_video or self.tef_only:
            stream_dists.append(self._compute_cdist_inference(
                query_embeddings, video_moment_embeddings, moment_mask))
        if self.use_sub:
            stream_dists.append(self._compute_cdist_inference(
                query_embeddings, sub_moment_embeddings, moment_mask))
        return sum(stream_dists) / len(stream_dists)  # (N_q, N_prop)

    def forward(self, query_feat, query_mask, pos_moment_video_feat, pos_moment_video_mask,
                intra_neg_moment_video_feat, intra_neg_moment_video_mask,
                inter_neg_moment_video_feat, inter_neg_moment_video_mask,
                pos_moment_sub_feat, pos_moment_sub_mask,
                intra_neg_moment_sub_feat, intra_neg_moment_sub_mask,
                inter_neg_moment_sub_feat, inter_neg_moment_sub_mask):
        """Compute the training loss for a batch.

        For every moment, the subtitle mask is used when the subtitle stream is
        active and the video mask otherwise; the chosen mask is applied to both
        streams.

        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_video_feat: (N, L_clip_1, D_v)
            pos_moment_video_mask: (N, L_clip_1)
            intra_neg_moment_video_feat: (N, L_clip_2, D_v)
            intra_neg_moment_video_mask: (N, L_clip_2)
            inter_neg_moment_video_feat: (N, L_clip_3, D_v)
            inter_neg_moment_video_mask: (N, L_clip_3)
            pos_moment_sub_feat: subtitle counterpart of pos_moment_video_feat
            pos_moment_sub_mask: subtitle counterpart of pos_moment_video_mask
            intra_neg_moment_sub_feat: subtitle counterpart of intra_neg_moment_video_feat
            intra_neg_moment_sub_mask: subtitle counterpart of intra_neg_moment_video_mask
            inter_neg_moment_sub_feat: subtitle counterpart of inter_neg_moment_video_feat
            inter_neg_moment_sub_mask: subtitle counterpart of inter_neg_moment_video_mask
        Returns:
            Scalar loss: intra-negative ranking loss plus
            config.inter_loss_weight times the inter-negative ranking loss.
        """
        query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)

        pos_mask = pos_moment_sub_mask if self.use_sub else pos_moment_video_mask
        pos_dist = self.compute_pdist(
            query_embed, pos_moment_video_feat, pos_moment_sub_feat,
            moment_mask=pos_mask)  # (N, )

        intra_neg_mask = intra_neg_moment_sub_mask if self.use_sub else intra_neg_moment_video_mask
        intra_neg_dist = self.compute_pdist(
            query_embed, intra_neg_moment_video_feat, intra_neg_moment_sub_feat,
            moment_mask=intra_neg_mask)  # (N, )

        if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
            # Skip encoding the inter negatives altogether.
            loss_inter = 0.
        else:
            inter_neg_mask = (inter_neg_moment_sub_mask if self.use_sub
                              else inter_neg_moment_video_mask)
            inter_neg_dist = self.compute_pdist(
                query_embed, inter_neg_moment_video_feat, inter_neg_moment_sub_feat,
                moment_mask=inter_neg_mask)  # (N, )
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss_intra = self.calc_loss(pos_dist, intra_neg_dist)
        return loss_intra + self.config.inter_loss_weight * loss_inter

    def calc_loss(self, pos_dist, neg_dist):
        """Ranking loss; encourages positive distance to be smaller than negative distance.

        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        Returns:
            Scalar tensor, the per-sample loss averaged over N.
        Raises:
            NotImplementedError: if config.loss_type is neither 'hinge' nor 'lse'.
        """
        loss_type = self.config.loss_type
        if loss_type == "hinge":  # max(0, m + S_pos - S_neg)
            per_sample = torch.clamp(self.config.margin + pos_dist - neg_dist, min=0)
        elif loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
            # softplus(x) == log(1 + exp(x)) but does not overflow to inf for
            # large x the way an explicit exp() does.
            per_sample = F.softplus(pos_dist - neg_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
        return per_sample.sum() / len(pos_dist)