import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureResizer(nn.Module):
    """
    Takes as input a set of embeddings of dimension C1 and outputs a set of
    embeddings of dimension C2, after a linear transformation, dropout and
    layer normalization (LN).
    """

    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
        super().__init__()
        self.do_ln = do_ln
        # Object feature encoding
        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_features):
        x = self.fc(encoder_features)
        if self.do_ln:
            x = self.layer_norm(x)
        output = self.dropout(x)
        return output


def _make_conv(input_dim, output_dim, k, stride=1):
    # k x k convolution with "same" padding, followed by BN and ReLU.
    pad = (k - 1) // 2
    return nn.Sequential(
        nn.Conv2d(input_dim, output_dim, (k, k), padding=(pad, pad), stride=(stride, stride)),
        nn.BatchNorm2d(output_dim),
        nn.ReLU(inplace=True),
    )


def _make_mlp(input_dim, output_dim, drop):
    # Two-layer MLP with BN, ReLU and dropout; expects (N, C) inputs because of BatchNorm1d.
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(drop),
        nn.Linear(output_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )


def _make_coord(batch, height, width):
    # Relative position encoding: an 8-channel map per location
    # (x_min, y_min, x_max, y_max, x_center, y_center, 1/height, 1/width).
    xv, yv = torch.meshgrid(torch.arange(0, height), torch.arange(0, width), indexing="ij")
    xv_min = (xv.float() * 2 - width) / width
    yv_min = (yv.float() * 2 - height) / height
    xv_max = ((xv + 1).float() * 2 - width) / width
    yv_max = ((yv + 1).float() * 2 - height) / height
    xv_ctr = (xv_min + xv_max) / 2
    yv_ctr = (yv_min + yv_max) / 2
    hmap = torch.ones(height, width) * (1.0 / height)
    wmap = torch.ones(height, width) * (1.0 / width)
    coord = torch.cat(
        [
            xv_min.unsqueeze(0),
            yv_min.unsqueeze(0),
            xv_max.unsqueeze(0),
            yv_max.unsqueeze(0),
            xv_ctr.unsqueeze(0),
            yv_ctr.unsqueeze(0),
            hmap.unsqueeze(0),
            wmap.unsqueeze(0),
        ],
        dim=0,
    )
    # --> (batch, 8, height, width)
    coord = coord.unsqueeze(0).repeat(batch, 1, 1, 1)
    return coord


def l1norm(X, dim, eps=1e-8):
    """L1-normalize X along `dim`."""
    norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
    X = torch.div(X, norm)
    return X


def l2norm(X, dim, eps=1e-8):
    """L2-normalize X along `dim`."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X


def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
    """
    query: (n_context, queryL, d)
    context: (n_context, sourceL, d)
    """
    batch_size_q, queryL = query.size(0), query.size(1)
    batch_size, sourceL = context.size(0), context.size(1)

    # Get attention
    # --> (batch, d, queryL)
    queryT = torch.transpose(query, 1, 2)

    # (batch, sourceL, d) x (batch, d, queryL)
    # --> (batch, sourceL, queryL)
    attn = torch.bmm(context, queryT)
    if raw_feature_norm == "softmax":
        # --> (batch*sourceL, queryL)
        attn = attn.view(batch_size * sourceL, queryL)
        attn = nn.Softmax(dim=1)(attn)
        # --> (batch, sourceL, queryL)
        attn = attn.view(batch_size, sourceL, queryL)
    elif raw_feature_norm == "l2norm":
        attn = l2norm(attn, 2)
    elif raw_feature_norm == "clipped_l2norm":
        attn = nn.LeakyReLU(0.1)(attn)
        attn = l2norm(attn, 2)
    else:
        raise ValueError(f"unknown first norm type: {raw_feature_norm}")

    # --> (batch, queryL, sourceL)
    attn = torch.transpose(attn, 1, 2).contiguous()
    # --> (batch*queryL, sourceL)
    attn = attn.view(batch_size * queryL, sourceL)
    attn = nn.Softmax(dim=1)(attn * smooth)
    # --> (batch, queryL, sourceL)
    attn = attn.view(batch_size, queryL, sourceL)
    # --> (batch, sourceL, queryL)
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # --> (batch, d, sourceL)
    contextT = torch.transpose(context, 1, 2)
    # (batch, d, sourceL) x (batch, sourceL, queryL)
    # --> (batch, d, queryL)
    weightedContext = torch.bmm(contextT, attnT)
    # --> (batch, queryL, d)
    weightedContext = torch.transpose(weightedContext, 1, 2)

    return weightedContext, attnT


class MultiHeadAttention(nn.Module):
    """Multi-head attention module for both image and text."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        # Note: defined here but not applied in forward().
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # Dot-product attention; note that no 1/sqrt(d_k) scaling is applied to the logits here.
        attn = torch.matmul(q, k.transpose(2, 3))
        attn = self.dropout(F.softmax(attn, dim=-1))
        q = torch.matmul(attn, v)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))

        return q, attn
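

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The batch size, token counts and feature dimensions below are assumptions
# chosen purely to smoke-test shapes; the real values depend on the
# surrounding model.
if __name__ == "__main__":
    torch.manual_seed(0)

    # FeatureResizer: project 768-d text embeddings down to 256-d.
    resizer = FeatureResizer(input_feat_size=768, output_feat_size=256, dropout=0.1)
    text_feat = torch.randn(2, 20, 768)       # (batch, num_tokens, C1)
    print(resizer(text_feat).shape)           # (2, 20, 256)

    # func_attention: attend from 20 query tokens to 49 source positions.
    query = torch.randn(2, 20, 256)           # (batch, queryL, d)
    context = torch.randn(2, 49, 256)         # (batch, sourceL, d)
    weighted, attnT = func_attention(query, context, smooth=4, raw_feature_norm="softmax")
    print(weighted.shape, attnT.shape)        # weighted: (2, 20, 256), attnT: (2, 49, 20)

    # MultiHeadAttention: 8 heads of size 32 over 256-d features.
    mha = MultiHeadAttention(n_head=8, d_model=256, d_k=32, d_v=32)
    out, attn = mha(query, context, context)
    print(out.shape, attn.shape)              # out: (2, 20, 256), attn: (2, 8, 20, 49)

    # _make_coord: 8-channel relative-coordinate map for a batch of 16 x 16 grids.
    coord = _make_coord(2, 16, 16)
    print(coord.shape)                        # (2, 8, 16, 16)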