import torch
import torch.nn as nn

from nncore.nn import MODELS, build_model


@MODELS.register()
class R2Block(nn.Module):

    def __init__(self,
                 dims,
                 in_dims,
                 k=4,
                 dropout=0.5,
                 use_tef=True,
                 pos_cfg=None,
                 tem_cfg=None):
        super(R2Block, self).__init__()

        # Project per-layer video features (optionally concatenated with the
        # 2-d temporal endpoint features) into the shared `dims` space.
        self.video_map = nn.Sequential(
            nn.LayerNorm((in_dims[0] + 2) if use_tef else in_dims[0]),
            nn.Dropout(dropout),
            nn.Linear((in_dims[0] + 2) if use_tef else in_dims[0], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))

        # Project per-layer query features into the same space.
        self.query_map = nn.Sequential(
            nn.LayerNorm(in_dims[1]),
            nn.Dropout(dropout),
            nn.Linear(in_dims[1], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))

        # Learnable sigmoid gates for fusing adjacent layers, only needed when
        # more than one layer is kept.
        if k > 1:
            self.gate = nn.Parameter(torch.zeros([k - 1]))

        # Projections for query-to-patch attention and a per-layer tanh scale
        # applied to the pooled residual.
        self.v_map = nn.Linear(dims, dims)
        self.q_map = nn.Linear(dims, dims)
        self.scale = nn.Parameter(torch.zeros([k]))

        # Positional encoding and temporal modeling sub-modules built from the
        # given configs.
        self.pos = build_model(pos_cfg, dims=dims)
        self.tem = build_model(tem_cfg, dims=dims)

        self.dims = dims
        self.in_dims = in_dims
        self.k = k
        self.dropout = dropout
        self.use_tef = use_tef
    def forward(self, video_emb, query_emb, video_msk, query_msk):
        # Keep only the last k layers of the multi-layer features.
        # video_emb: (layers, batch, time, patches, channels)
        # query_emb: (layers, batch, tokens, channels)
        video_emb = video_emb[-self.k:]
        query_emb = query_emb[-self.k:]

        _, b, t, p, _ = video_emb.size()

        # Append temporal endpoint features (normalized start/end times of
        # each clip) to the video features.
        if self.use_tef:
            tef_s = torch.arange(0, 1, 1 / t, device=video_emb.device)
            tef_e = tef_s + 1.0 / t
            tef = torch.stack((tef_s, tef_e), dim=1)
            tef = tef.unsqueeze(1).unsqueeze(0).unsqueeze(0).repeat(self.k, b, 1, p, 1)
            # The slice guards against an extra step produced by arange due to
            # floating-point rounding.
            video_emb = torch.cat((video_emb, tef[:, :, :video_emb.size(2)]), dim=-1)

        coll_v, coll_q, last = [], [], None
        # Iterate over the k selected layers in reverse index order, carrying
        # the fused state `last` across iterations.
        for i in range(self.k - 1, -1, -1):
            v_emb = self.video_map(video_emb[i])
            q_emb = self.query_map(query_emb[i])

            # Collect the first spatial token of each clip and the projected
            # query features for this layer.
            coll_v.append(v_emb[:, :, 0])
            coll_q.append(q_emb)

            # Flatten batch and time so each clip's patch tokens form one
            # attention target; repeat the query once per clip.
            v_pool = v_emb.view(b * t, -1, self.dims)
            q_pool = q_emb.repeat_interleave(t, dim=0)

            v_pool_map = self.v_map(v_pool)
            q_pool_map = self.q_map(q_pool)

            # Scaled dot-product attention from query tokens to patch tokens.
            att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2)) / self.dims**0.5
            att = att.softmax(-1)

            # Query-conditioned spatial pooling: attend over patches, add the
            # query residual, then max-pool over query tokens.
            o_pool = torch.bmm(att, v_pool) + q_pool
            o_pool = o_pool.amax(dim=1, keepdim=True)
            v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh()
            v_emb = v_emb.view(b, t, self.dims)

            # Gated fusion with the state produced by the previous iteration.
            if i < self.k - 1:
                gate = self.gate[i].sigmoid()
                v_emb = gate * v_emb + (1 - gate) * last

            # Positional encoding followed by the temporal module, which
            # attends the video sequence to the query under the given masks.
            v_pe = self.pos(v_emb)
            last = self.tem(v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk)

        return last, q_emb, coll_v, coll_q
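

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It only illustrates
# the tensor shapes R2Block expects; the feature sizes below and the stubs for
# `pos` / `tem` are assumptions. It also assumes `build_model` returns None
# for a None config. In a real setup, `pos_cfg` and `tem_cfg` should be the
# project's own nncore model configs; the stubs here merely keep the sketch
# self-contained.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    k, b, t, p, n = 4, 2, 16, 50, 20    # layers kept, batch, clips, patches, tokens
    v_dim, q_dim, dims = 768, 512, 256  # assumed video/query feature sizes

    model = R2Block(dims, (v_dim, q_dim), k=k, pos_cfg=None, tem_cfg=None)

    # Stub the config-built sub-modules so the sketch runs without configs
    # (assumption: in practice these would be a positional encoding and a
    # cross-attention temporal block).
    model.pos = lambda x: torch.zeros_like(x)
    model.tem = lambda v, q, q_pe=None, q_mask=None, k_mask=None: v

    video_emb = torch.randn(k, b, t, p, v_dim)  # (layers, batch, time, patches, channels)
    query_emb = torch.randn(k, b, n, q_dim)     # (layers, batch, tokens, channels)
    video_msk = torch.ones(b, t)
    query_msk = torch.ones(b, n)

    last, q_emb, coll_v, coll_q = model(video_emb, query_emb, video_msk, query_msk)
    print(last.shape)  # expected: (b, t, dims)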