Spaces:
Sleeping
Sleeping
File size: 6,399 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureResizer(nn.Module):
    """Project embeddings from one feature width to another.

    The input is passed through a bias-enabled linear layer, optionally
    through LayerNorm, and finally through dropout.

    Args:
        input_feat_size: width of the incoming feature vectors (C1).
        output_feat_size: width of the produced feature vectors (C2).
        dropout: drop probability applied to the output.
        do_ln: when true, LayerNorm is applied after the linear projection.
            The LayerNorm module is always created (so its parameters are
            always part of the state dict), but only used when this is set.
    """

    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
        super().__init__()
        self.do_ln = do_ln
        # Projection from C1 to C2.
        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
        # Registered unconditionally; applied only if ``do_ln`` is set.
        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_features):
        """Return the resized features: linear -> (optional LN) -> dropout."""
        projected = self.fc(encoder_features)
        normalized = self.layer_norm(projected) if self.do_ln else projected
        return self.dropout(normalized)
def _make_conv(input_dim, output_dim, k, stride=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack with a square k x k kernel.

    Padding is ``(k - 1) // 2`` on each side, which keeps the spatial size
    unchanged for odd ``k`` at stride 1.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        k: side length of the square kernel.
        stride: stride used in both spatial directions.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    half = (k - 1) // 2
    conv = nn.Conv2d(
        input_dim,
        output_dim,
        kernel_size=(k, k),
        stride=(stride, stride),
        padding=(half, half),
    )
    layers = [conv, nn.BatchNorm2d(output_dim), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)
def _make_mlp(input_dim, output_dim, drop):
    """Build a two-layer MLP with BatchNorm1d and ReLU after each linear layer.

    Layout: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear -> BatchNorm1d -> ReLU.

    Args:
        input_dim: width of the input features.
        output_dim: width of both the hidden and the output features.
        drop: dropout probability applied between the two linear layers.

    Returns:
        An ``nn.Sequential`` holding the seven layers in order.
    """
    first_stage = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    regularizer = [nn.Dropout(drop)]
    second_stage = [
        nn.Linear(output_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*first_stage, *regularizer, *second_stage)
def _make_coord(batch, height, width):
    """Build an 8-channel relative position encoding for a height x width grid.

    Channels, in order: ``xv_min, yv_min, xv_max, yv_max, xv_ctr, yv_ctr,
    1/height, 1/width``. The min/max channels are the normalised lower/upper
    edges of each cell and the ctr channels are their midpoints.

    Args:
        batch: number of copies stacked along the leading dimension.
        height: number of grid rows.
        width: number of grid columns.

    Returns:
        A float tensor of shape ``(batch, 8, height, width)`` created on the
        default device; it does not require grad.
    """
    # Index grids built by broadcasting. This is equivalent to
    # torch.meshgrid(..., indexing="ij") but avoids the warning recent PyTorch
    # releases emit when ``indexing`` is omitted, while still working on
    # releases that predate that keyword.
    xv = torch.arange(0, height).unsqueeze(1).expand(height, width)  # row index
    yv = torch.arange(0, width).unsqueeze(0).expand(height, width)  # column index

    # NOTE(review): ``xv`` holds the row index yet is normalised by ``width``
    # (and ``yv``, the column index, by ``height``). For square grids this makes
    # no difference; it is kept unchanged so the numeric output stays identical
    # to what existing callers receive - confirm before "fixing".
    xv_min = (xv.float() * 2 - width) / width
    yv_min = (yv.float() * 2 - height) / height
    xv_max = ((xv + 1).float() * 2 - width) / width
    yv_max = ((yv + 1).float() * 2 - height) / height
    xv_ctr = (xv_min + xv_max) / 2
    yv_ctr = (yv_min + yv_max) / 2

    # Constant maps carrying the size of one cell relative to the grid.
    hmap = torch.ones(height, width) * (1.0 / height)
    wmap = torch.ones(height, width) * (1.0 / width)

    # The deprecated torch.autograd.Variable wrapper is dropped: since
    # PyTorch 0.4 it simply returns the tensor it is given.
    coord = torch.stack(
        [xv_min, yv_min, xv_max, yv_max, xv_ctr, yv_ctr, hmap, wmap],
        dim=0,
    )
    return coord.unsqueeze(0).repeat(batch, 1, 1, 1)
def l1norm(X, dim, eps=1e-8):
    """Scale X so that its absolute values sum to (about) one along ``dim``.

    ``eps`` is added to the L1 norm before dividing, so an all-zero slice
    yields zeros instead of NaNs.
    """
    denom = X.abs().sum(dim=dim, keepdim=True) + eps
    return X / denom
def l2norm(X, dim, eps=1e-8):
    """Scale X to (about) unit Euclidean length along ``dim``.

    ``eps`` is added to the L2 norm before dividing, so an all-zero slice
    yields zeros instead of NaNs.
    """
    sum_of_squares = X.pow(2).sum(dim=dim, keepdim=True)
    denom = sum_of_squares.sqrt() + eps
    return X / denom
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
    """Attend from each query position over the context positions.

    Args:
        query: (n_context, queryL, d)
        context: (n_context, sourceL, d)
        smooth: factor multiplied into the normalised similarities before the
            final softmax over the context positions.
        raw_feature_norm: how the raw similarities are normalised over the
            query axis first: ``"softmax"``, ``"l2norm"`` or
            ``"clipped_l2norm"`` (LeakyReLU with slope 0.1, then L2 norm).
        eps: not used by this function; kept so existing call sites that pass
            it keep working.

    Returns:
        A tuple ``(weightedContext, attnT)`` where ``weightedContext`` is
        (batch, queryL, d) and ``attnT`` is (batch, sourceL, queryL).

    Raises:
        ValueError: if ``raw_feature_norm`` is not one of the known names.
    """
    # Raw similarities.
    # (batch, sourceL, d)(batch, d, queryL) --> (batch, sourceL, queryL)
    attn = torch.bmm(context, torch.transpose(query, 1, 2))

    if raw_feature_norm == "softmax":
        # Normalise over the query axis. An explicit ``dim`` replaces the old
        # flatten-to-2D + dim-less nn.Softmax(), which relied on the
        # deprecated implicit-dimension behaviour and emitted a warning.
        attn = F.softmax(attn, dim=2)
    elif raw_feature_norm == "l2norm":
        attn = l2norm(attn, 2)
    elif raw_feature_norm == "clipped_l2norm":
        attn = F.leaky_relu(attn, negative_slope=0.1)
        attn = l2norm(attn, 2)
    else:
        raise ValueError(f"unknown first norm type: {raw_feature_norm}")

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # Sharpen with ``smooth`` and normalise over the context positions.
    attn = F.softmax(attn * smooth, dim=2)

    # --> (batch, sourceL, queryL)
    attnT = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # (batch, d, sourceL)(batch, sourceL, queryL) --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)
    return weightedContext, attnT
class MultiHeadAttention(nn.Module):
    """Multi-head attention module for both image and text.

    Queries, keys and values are linearly projected (no bias), split into
    ``n_head`` heads, combined with softmax dot-product attention, merged
    again and projected back to ``d_model``.

    As written, the attention logits are the plain dot products (no scaling
    factor is applied), and ``layer_norm`` is registered as a submodule but is
    not applied in ``forward``.

    Args:
        n_head: number of attention heads.
        d_model: feature width of the inputs and of the output.
        d_k: per-head width of the query/key projections.
        d_v: per-head width of the value projection.
        dropout: drop probability used on the attention weights and on the
            output projection.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v

        qk_width = n_head * d_k
        v_width = n_head * d_v
        # Input projections for queries, keys and values.
        self.w_qs = nn.Linear(d_model, qk_width, bias=False)
        self.w_ks = nn.Linear(d_model, qk_width, bias=False)
        self.w_vs = nn.Linear(d_model, v_width, bias=False)
        # Output projection applied to the concatenated heads.
        self.fc = nn.Linear(v_width, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v):
        """Run attention.

        Args:
            q: (b, lq, d_model)
            k: (b, lk, d_model)
            v: (b, lv, d_model)

        Returns:
            A tuple ``(output, attn)``: ``output`` is (b, lq, d_model) and
            ``attn`` is the (post-dropout) attention weights, (b, n, lq, lk).
        """
        heads = self.n_head
        batch, q_len = q.size(0), q.size(1)
        k_len, v_len = k.size(1), v.size(1)

        # Project, split off the head axis, then move it ahead of the
        # sequence axis: (b, len, n * dim) -> (b, n, len, dim).
        q_heads = self.w_qs(q).view(batch, q_len, heads, self.d_k).transpose(1, 2)
        k_heads = self.w_ks(k).view(batch, k_len, heads, self.d_k).transpose(1, 2)
        v_heads = self.w_vs(v).view(batch, v_len, heads, self.d_v).transpose(1, 2)

        # Dot-product scores, normalised over the key positions.
        scores = torch.matmul(q_heads, k_heads.transpose(2, 3))
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads: (b, lq, n * dv).
        mixed = torch.matmul(attn, v_heads)
        merged = mixed.transpose(1, 2).contiguous().view(batch, q_len, -1)

        return self.dropout(self.fc(merged)), attn
|