import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureResizer(nn.Module):
    """
    Takes a set of embeddings of dimension C1 and outputs a set of embeddings of
    dimension C2, after a linear transformation, layer normalization (LN) and dropout.
    """

    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
        super().__init__()
        self.do_ln = do_ln
        # Object feature encoding
        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_features):
        x = self.fc(encoder_features)
        if self.do_ln:
            x = self.layer_norm(x)
        output = self.dropout(x)
        return output
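
# Usage sketch (not from the original source): mapping 768-d text-encoder features
# into a 256-d visual feature space; the 768/256 sizes are illustrative assumptions.
#
#   resizer = FeatureResizer(input_feat_size=768, output_feat_size=256, dropout=0.1)
#   text_feat = torch.randn(2, 20, 768)   # (batch, num_tokens, C1)
#   resized = resizer(text_feat)          # (batch, num_tokens, C2) -> (2, 20, 256)
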
def _make_conv(input_dim, output_dim, k, stride=1):
    pad = (k - 1) // 2
    return nn.Sequential(
        nn.Conv2d(input_dim, output_dim, (k, k), padding=(pad, pad), stride=(stride, stride)),
        nn.BatchNorm2d(output_dim),
        nn.ReLU(inplace=True),
    )
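
# Illustrative check (assumed sizes, not from the original source): a k=3, stride=1
# conv block keeps the spatial resolution because of the (k - 1) // 2 padding.
#
#   conv = _make_conv(256, 512, k=3)
#   out = conv(torch.randn(2, 256, 32, 32))   # (2, 512, 32, 32)
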
def _make_mlp(input_dim, output_dim, drop):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(drop),
        nn.Linear(output_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )
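
# Illustrative check (assumed sizes): the MLP expects flat (N, input_dim) inputs and,
# because of BatchNorm1d, a batch of N > 1 while in training mode.
#
#   mlp = _make_mlp(512, 256, drop=0.1)
#   out = mlp(torch.randn(4, 512))            # (4, 256)
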
def _make_coord(batch, height, width):
    # relative position encoding
    xv, yv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)])
    xv_min = (xv.float() * 2 - width) / width
    yv_min = (yv.float() * 2 - height) / height
    xv_max = ((xv + 1).float() * 2 - width) / width
    yv_max = ((yv + 1).float() * 2 - height) / height
    xv_ctr = (xv_min + xv_max) / 2
    yv_ctr = (yv_min + yv_max) / 2
    hmap = torch.ones(height, width) * (1.0 / height)
    wmap = torch.ones(height, width) * (1.0 / width)
    coord = torch.cat(
        [
            xv_min.unsqueeze(0),
            yv_min.unsqueeze(0),
            xv_max.unsqueeze(0),
            yv_max.unsqueeze(0),
            xv_ctr.unsqueeze(0),
            yv_ctr.unsqueeze(0),
            hmap.unsqueeze(0),
            wmap.unsqueeze(0),
        ],
        dim=0,
    )
    coord = coord.unsqueeze(0).repeat(batch, 1, 1, 1)
    return coord
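
# Illustrative check (assumed sizes): the output stacks 8 coordinate channels
# (x_min, y_min, x_max, y_max, x_ctr, y_ctr, 1/h, 1/w) per spatial location.
#
#   coord = _make_coord(2, 16, 16)            # (batch, 8, height, width) -> (2, 8, 16, 16)
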
def l1norm(X, dim, eps=1e-8):
    """L1-normalize X along the given dimension."""
    norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
    X = torch.div(X, norm)
    return X


def l2norm(X, dim, eps=1e-8):
    """L2-normalize X along the given dimension."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X
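
# Illustrative check (assumed sizes): after l2norm along the last dimension, each
# feature vector has (approximately) unit L2 norm.
#
#   feats = l2norm(torch.randn(2, 5, 8), dim=-1)
#   feats.norm(dim=-1)                        # ~1.0 everywhere
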
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
    """
    query: (n_context, queryL, d)
    context: (n_context, sourceL, d)
    """
    batch_size_q, queryL = query.size(0), query.size(1)
    batch_size, sourceL = context.size(0), context.size(1)

    # Get attention
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d)(batch, d, queryL)
    # --> (batch, sourceL, queryL)
    attn = torch.bmm(context, queryT)
    if raw_feature_norm == "softmax":
        # --> (batch*sourceL, queryL)
        attn = attn.view(batch_size * sourceL, queryL)
        attn = F.softmax(attn, dim=1)
        # --> (batch, sourceL, queryL)
        attn = attn.view(batch_size, sourceL, queryL)
    elif raw_feature_norm == "l2norm":
        attn = l2norm(attn, 2)
    elif raw_feature_norm == "clipped_l2norm":
        attn = nn.LeakyReLU(0.1)(attn)
        attn = l2norm(attn, 2)
    else:
        raise ValueError("unknown first norm type:", raw_feature_norm)

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch*queryL, sourceL)
    attn = attn.view(batch_size * queryL, sourceL)
    attn = F.softmax(attn * smooth, dim=1)
    # --> (batch, queryL, sourceL)
    attn = attn.view(batch_size, queryL, sourceL)

    # --> (batch, sourceL, queryL)
    attnT = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # (batch x d x sourceL)(batch x sourceL x queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)

    return weightedContext, attnT
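
# Usage sketch (assumed sizes, not from the original source): attending a set of
# query vectors over a set of context vectors with the same feature dimension d.
#
#   query = torch.randn(2, 5, 256)            # (batch, queryL, d)
#   context = torch.randn(2, 10, 256)         # (batch, sourceL, d)
#   weighted, attnT = func_attention(query, context)
#   # weighted: (2, 5, 256)  -- context aggregated per query
#   # attnT:    (2, 10, 5)   -- (batch, sourceL, queryL) attention weights
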
class MultiHeadAttention(nn.Module):
    """Multi-head attention module for both image and text"""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        attn = torch.matmul(q, k.transpose(2, 3))
        attn = self.dropout(F.softmax(attn, dim=-1))
        q = torch.matmul(attn, v)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        return q, attn
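
# Usage sketch (assumed sizes, not from the original source). Note that the raw
# attention scores are not scaled by sqrt(d_k) and self.layer_norm is defined but
# not applied inside forward().
#
#   mha = MultiHeadAttention(n_head=8, d_model=256, d_k=32, d_v=32)
#   q = torch.randn(2, 5, 256)                # (batch, len_q, d_model)
#   kv = torch.randn(2, 10, 256)              # (batch, len_k, d_model)
#   out, attn = mha(q, kv, kv)                # out: (2, 5, 256), attn: (2, 8, 5, 10)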