import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    # Conv -> BN -> ReLU; the conv bias is dropped because BatchNorm absorbs it.
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_dim), nn.ReLU(True))


def linear_layer(in_dim, out_dim, bias=False):
    # Linear -> BN -> ReLU counterpart for 1D (sentence-level) features.
    return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
                         nn.BatchNorm1d(out_dim), nn.ReLU(True))


class CoordConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 stride=1):
        super().__init__()
        # Two extra input channels hold the appended x/y coordinate maps.
        self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
                                padding, stride)

    def add_coord(self, input):
        b, _, h, w = input.size()
        x_range = torch.linspace(-1, 1, w, device=input.device)
        y_range = torch.linspace(-1, 1, h, device=input.device)
        # indexing='ij' keeps the (h, w) layout and silences the warning that
        # PyTorch >= 1.10 raises when no indexing mode is given.
        y, x = torch.meshgrid(y_range, x_range, indexing='ij')
        y = y.expand([b, 1, -1, -1])
        x = x.expand([b, 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        input = torch.cat([input, coord_feat], 1)
        return input

    def forward(self, x):
        x = self.add_coord(x)
        x = self.conv1(x)
        return x
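

# Illustrative usage: CoordConv(512, 512)(torch.randn(2, 512, 26, 26)) returns
# a (2, 512, 26, 26) map; normalizing the coordinates to [-1, 1] gives the
# convolution a notion of absolute position at any input resolution.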


class Projector(nn.Module):
    def __init__(self, word_dim=1024, in_dim=256, kernel_size=3):
        super().__init__()
        self.in_dim = in_dim
        self.kernel_size = kernel_size

        # Visual head: 4x bilinear upsampling, reducing in_dim*2 -> in_dim channels.
        self.vis = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim * 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim, 3, padding=1),
            nn.Conv2d(in_dim, in_dim, 1))

        # Textual head: predicts one kernel_size x kernel_size conv kernel
        # plus a scalar bias per sentence.
        out_dim = 1 * in_dim * kernel_size * kernel_size + 1
        self.txt = nn.Linear(word_dim, out_dim)

    def forward(self, x, word):
        '''
        x: b, 512, 26, 26
        word: b, word_dim
        '''
        x = self.vis(x)
        B, C, H, W = x.size()

        # Fold the batch into the channel axis so a grouped conv can apply a
        # different kernel to each sample: (1, B*C, H, W).
        x = x.reshape(1, B * C, H, W)

        # Split the text projection into per-sample kernels and biases.
        word = self.txt(word)
        weight, bias = word[:, :-1], word[:, -1]
        weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)

        # Grouped conv: group i convolves sample i's features with its own kernel.
        out = F.conv2d(x,
                       weight,
                       padding=self.kernel_size // 2,
                       groups=weight.size(0),
                       bias=bias)
        # (1, B, H, W) -> (B, 1, H, W)
        out = out.transpose(0, 1)

        return out
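

# Dynamic-filter walk-through (illustrative, using the defaults in_dim=256,
# kernel_size=3): self.txt maps each sentence vector to 256*3*3 + 1 = 2305
# numbers; the first 2304 become that sample's 3x3 conv kernel and the last
# one its bias. With the batch folded into channels and groups=B, a single
# F.conv2d call scores every pixel against its own sentence, producing one
# (b, 1, H, W) response map per image-text pair.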


class TransformerDecoder(nn.Module):
    def __init__(self,
                 num_layers,
                 d_model,
                 nhead,
                 dim_ffn,
                 dropout,
                 return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_ffn,
                                    dropout=dropout) for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate

    @staticmethod
    def pos1d(d_model, length):
        """
        :param d_model: dimension of the model
        :param length: length of positions
        :return: length x 1 x d_model position matrix
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe.unsqueeze(1)
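
    # pos1d realizes the sinusoidal encoding of Vaswani et al. (2017):
    # PE[p, 2i] = sin(p / 10000^(2i/d)) and PE[p, 2i+1] = cos(p / 10000^(2i/d)),
    # since div_term = exp(-2i * ln(10000) / d) = 10000^(-2i/d).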

    @staticmethod
    def pos2d(d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: (height*width) x 1 x d_model position matrix
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "a dimension not divisible by 4 "
                             "(got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)

        # Each spatial axis gets half of the channels.
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe.reshape(-1, 1, height * width).permute(2, 1, 0)
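
    # The first d_model/2 channels encode the x (width) coordinate and the
    # remaining half the y (height) coordinate, each with the 1D scheme above,
    # so every grid cell receives a unique d_model-dimensional code.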

    def forward(self, vis, txt, pad_mask):
        '''
        vis: b, 512, h, w
        txt: b, L, 512
        pad_mask: b, L
        '''
        B, C, H, W = vis.size()
        _, L, D = txt.size()

        # Fixed sinusoidal position encodings for pixels and words.
        vis_pos = self.pos2d(C, H, W)
        txt_pos = self.pos1d(D, L)

        # (b, C, h, w) -> (h*w, b, C) and (b, L, D) -> (L, b, D), the
        # sequence-first layout expected by nn.MultiheadAttention.
        vis = vis.reshape(B, C, -1).permute(2, 0, 1)
        txt = txt.permute(1, 0, 2)

        output = vis
        intermediate = []
        for layer in self.layers:
            output = layer(output, txt, vis_pos, txt_pos, pad_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(output).permute(1, 2, 0))

        output = self.norm(output).permute(1, 2, 0)  # (b, C, h*w)
        if self.return_intermediate:
            # Replace the last entry with the final normalized output.
            intermediate.pop()
            intermediate.append(output)
            return intermediate
        return output


class TransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model=512,
                 nhead=8,
                 dim_feedforward=2048,
                 dropout=0.1):
        super().__init__()

        self.self_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_norm = nn.LayerNorm(d_model)

        # Self-attention over pixels, then cross-attention from pixels to
        # words. nn.MultiheadAttention requires d_model % nhead == 0, hence
        # the default nhead of 8 for d_model=512.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model,
                                                    nhead,
                                                    dropout=dropout,
                                                    kdim=d_model,
                                                    vdim=d_model)

        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.ReLU(True), nn.Dropout(dropout),
                                 nn.LayerNorm(dim_feedforward),
                                 nn.Linear(dim_feedforward, d_model))

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos.to(tensor.device)

    def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
        '''
        vis: 26*26, b, 512
        txt: L, b, 512
        vis_pos: 26*26, 1, 512
        txt_pos: L, 1, 512
        pad_mask: b, L
        '''
        # Pre-norm self-attention block.
        vis2 = self.norm1(vis)
        q = k = self.with_pos_embed(vis2, vis_pos)
        vis2 = self.self_attn(q, k, value=vis2)[0]
        vis2 = self.self_attn_norm(vis2)
        vis = vis + self.dropout1(vis2)

        # Pre-norm cross-attention block; padded words are masked out.
        vis2 = self.norm2(vis)
        vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
                                   key=self.with_pos_embed(txt, txt_pos),
                                   value=txt,
                                   key_padding_mask=pad_mask)[0]
        vis2 = self.cross_attn_norm(vis2)
        vis = vis + self.dropout2(vis2)

        # Pre-norm feed-forward block.
        vis2 = self.norm3(vis)
        vis2 = self.ffn(vis2)
        vis = vis + self.dropout3(vis2)
        return vis


class FPN(nn.Module):
    def __init__(self,
                 in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super().__init__()

        # Project the sentence-level text feature to the top visual dimension.
        self.txt_proj = linear_layer(in_channels[2], out_channels[2])

        # Fusion 1: the deepest feature v5, modulated by the text feature.
        self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
                                        nn.ReLU(True))

        # Fusion 2: v4 combined with the upsampled f5.
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
                                 out_channels[1], 1, 0)

        # Fusion 3: v3 (downsampled) combined with f4.
        self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)

        # Bring all three levels to a common channel width.
        self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)

        # Aggregate the concatenated levels and add coordinate features.
        self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[1], out_channels[1], 3, 1),
            conv_layer(out_channels[1], out_channels[1], 3, 1))

    def forward(self, imgs, state):
        # v3, v4, v5: backbone features from shallow to deep.
        v3, v4, v5 = imgs

        # (b, C) -> (b, C, 1, 1) for channel-wise modulation of f5.
        state = self.txt_proj(state).unsqueeze(-1).unsqueeze(-1)
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)

        f4 = self.f2_v_proj(v4)
        f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))

        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))

        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)

        # Align the spatial sizes at the middle resolution, then fuse.
        fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        fq = self.coordconv(fq)

        return fq, f5
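

if __name__ == '__main__':
    # Minimal shape-check sketch, not part of the original module. The batch
    # size, word count, and 52/26/13 feature resolutions are assumptions read
    # off the docstrings above; eval() keeps BatchNorm and dropout inert.
    b, l = 2, 17
    fpn = FPN().eval()
    decoder = TransformerDecoder(num_layers=1, d_model=512, nhead=8,
                                 dim_ffn=2048, dropout=0.1).eval()
    projector = Projector(word_dim=1024, in_dim=256).eval()

    v3 = torch.randn(b, 512, 52, 52)
    v4 = torch.randn(b, 1024, 26, 26)
    v5 = torch.randn(b, 1024, 13, 13)
    state = torch.randn(b, 1024)                    # sentence-level feature
    txt = torch.randn(b, l, 512)                    # word-level features
    pad_mask = torch.zeros(b, l, dtype=torch.bool)  # no padded positions
    word = torch.randn(b, 1024)

    fq, _ = fpn([v3, v4, v5], state)                # (b, 512, 26, 26)
    out = decoder(fq, txt, pad_mask)                # (b, 512, 26 * 26)
    mask = projector(out.reshape(b, 512, 26, 26), word)
    print(fq.shape, mask.shape)                     # mask: (b, 1, 104, 104)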