import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision import models


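# Overview (added comments, not in the original file): a ResNet-101 backbone is
# interleaved with one All_Fusion_Block per stage. Each fusion block runs PWAM,
# which computes bidirectional, multi-scale attention between flattened visual
# features (B, H*W, C) and language features (B, 768, n_l), then feeds gated
# residuals back into both streams.

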
class MultiModalResNet(nn.Module):
    """ResNet-101 visual backbone with a language-fusion block after each stage."""

    def __init__(self, pretrained):
        super(MultiModalResNet, self).__init__()
        resnet = models.resnet101()
        if pretrained:
            # `pretrained` is expected to be a path to ResNet-101 weights.
            resnet.load_state_dict(torch.load(pretrained))

        # Stage 1 bundles the stem (conv1/bn1/relu/maxpool) with layer1.
        self.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1)
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # One fusion block per stage, matched to that stage's channel width.
        self.all_fusion1 = All_Fusion_Block(dim=256, num_heads=1, dropout=0.0)
        self.all_fusion2 = All_Fusion_Block(dim=512, num_heads=1, dropout=0.0)
        self.all_fusion3 = All_Fusion_Block(dim=1024, num_heads=1, dropout=0.0)
        self.all_fusion4 = All_Fusion_Block(dim=2048, num_heads=1, dropout=0.0)

    def forward(self, x, l, l_mask):
        # x: (B, 3, H, W) image, l: (B, 768, n_l) language features,
        # l_mask: (B, n_l, 1) mask with 1 for real tokens and 0 for padding.
        outs = []

        x = self.layer1(x)
        x, l = self.all_fusion1(x, l, l_mask)
        outs.append(x)

        x = self.layer2(x)
        x, l = self.all_fusion2(x, l, l_mask)
        outs.append(x)

        x = self.layer3(x)
        x, l = self.all_fusion3(x, l, l_mask)
        outs.append(x)

        x = self.layer4(x)
        x, l = self.all_fusion4(x, l, l_mask)
        outs.append(x)

        return l, tuple(outs)


class All_Fusion_Block(nn.Module):
    def __init__(self, dim, num_heads=1, dropout=0.0):
        super(All_Fusion_Block, self).__init__()

        # Pixel-word attention module; language features are 768-dimensional.
        self.fusion = PWAM(dim,
                           dim,
                           768,
                           dim,
                           dim,
                           num_heads=num_heads,
                           dropout=dropout)
        # Learned gate applied to the visual residual.
        self.res_gate = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.ReLU(),
            nn.Linear(dim, dim, bias=False),
            nn.Tanh()
        )
        # Projection applied to the language residual.
        self.W_l = nn.Sequential(
            nn.Conv1d(768, 768, 1, 1),
            nn.GELU()
        )

    def forward(self, x, l, l_mask):
        # x: (B, C, H, W) visual features; flatten to (B, H*W, C) for attention.
        H, W = x.shape[2], x.shape[3]
        x = x.view(x.shape[0], x.shape[1], H * W)
        x = x.permute(0, 2, 1).contiguous()
        x_residual, l_residual = self.fusion(x, l, l_mask)

        # Gated residual update for the visual stream, plain residual for language.
        x = x + (self.res_gate(x_residual) * x_residual)
        l = l + self.W_l(l_residual)
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(x.shape[0], x.shape[1], H, W)

        return x, l


class PWAM(nn.Module):
    def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=1, dropout=0.0):
        super(PWAM, self).__init__()

        # Project the visual features before element-wise modulation.
        self.vis_project = nn.Sequential(nn.Conv1d(dim, dim, 1, 1),
                                         nn.GELU(),
                                         nn.Dropout(dropout)
                                         )

        self.image_lang_att = SpatialImageInteraction(v_in_channels,
                                                      l_in_channels,
                                                      key_channels,
                                                      value_channels,
                                                      out_channels=value_channels,
                                                      num_heads=num_heads)

        self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1),
                                        nn.GELU(),
                                        nn.Dropout(dropout)
                                        )

    def forward(self, x, l, l_mask):
        # x: (B, H*W, dim), l: (B, l_in_channels, n_l), l_mask: (B, n_l, 1).
        vis = self.vis_project(x.permute(0, 2, 1))  # (B, dim, H*W)

        # lang: language-attended visual features; lang1: vision-attended language features.
        lang, lang1 = self.image_lang_att(x, l, l_mask)
        lang = lang.permute(0, 2, 1)  # (B, value_channels, H*W)

        mm = torch.mul(vis, lang)
        mm = self.project_mm(mm)
        mm = mm.permute(0, 2, 1)  # (B, H*W, value_channels)

        return mm, lang1


class SpatialImageInteraction(nn.Module):
    def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1):
        super(SpatialImageInteraction, self).__init__()

        self.v_in_channels = v_in_channels
        self.l_in_channels = l_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        self.value_channels_l = l_in_channels
        if out_channels is None:
            self.out_channels = self.value_channels

        # Language values for the vision-to-language attention.
        self.f_value = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.value_channels, kernel_size=1, stride=1),
        )

        # Visual values for the language-to-vision attention.
        self.f_value_v = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.value_channels_l, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.value_channels_l),
        )

        self.W = nn.Sequential(
            nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.out_channels),
        )
        self.W2 = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.l_in_channels, kernel_size=1, stride=1),
        )
        self.num_heads = num_heads

        # Multi-scale similarity maps: 1x1 / 3x3 / 5x5 visual neighbourhoods and
        # 1 / 2 / 3-token language windows, combined with learned softmax weights.
        self.refineimg11 = RefineVisualSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=1, num_heads=1)
        self.refineimg33 = RefineVisualSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=3, num_heads=1)
        self.refineimg55 = RefineVisualSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=5, num_heads=1)
        self.refinelan11 = RefineLanSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=(1, 1), num_heads=1)
        self.refinelan21 = RefineLanSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=(2, 1), num_heads=1)
        self.refinelan31 = RefineLanSim(self.v_in_channels, self.l_in_channels, self.key_channels, kernel=(3, 1), num_heads=1)

        self.vis_weight = nn.Parameter(torch.ones(3))
        self.lan_weight = nn.Parameter(torch.ones(3))

    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels), l: (B, l_in_channels, n_l), l_mask: (B, n_l, 1).
        B, HW = x.size(0), x.size(1)
        n_l = l.size(2)
        l_mask1 = l_mask.permute(0, 2, 1)  # (B, 1, n_l)

        # Language values, masked and split into heads.
        value = self.f_value(l)
        value = value * l_mask1
        value = value.reshape(B, self.num_heads, self.value_channels // self.num_heads, n_l)

        # Vision-to-language similarity at three spatial scales.
        sim_mapv11 = self.refineimg11(x, l, l_mask)
        sim_mapv33 = self.refineimg33(x, l, l_mask)
        sim_mapv55 = self.refineimg55(x, l, l_mask)

        vis_weight1 = F.softmax(self.vis_weight, dim=0)
        sim_mapv = vis_weight1[0] * sim_mapv11 + vis_weight1[1] * sim_mapv33 + vis_weight1[2] * sim_mapv55

        out_v = torch.matmul(sim_mapv, value.permute(0, 1, 3, 2))  # (B, heads, HW, value_channels/heads)
        out_v = out_v.permute(0, 2, 1, 3).contiguous().reshape(B, HW, self.value_channels)
        out_v = out_v.permute(0, 2, 1)
        out_v = self.W(out_v)
        out_v = out_v.permute(0, 2, 1)  # (B, HW, out_channels)

        # Visual values for the language stream.
        x_v = x.permute(0, 2, 1)
        x_v = self.f_value_v(x_v)
        value_v = x_v.reshape(B, self.num_heads, self.l_in_channels // self.num_heads, HW)

        # Language-to-vision similarity over 1/2/3-token windows.
        sim_mapl11 = self.refinelan11(x, l, l_mask)
        sim_mapl21 = self.refinelan21(x, l, l_mask)
        sim_mapl31 = self.refinelan31(x, l, l_mask)

        lan_weight1 = F.softmax(self.lan_weight, dim=0)
        sim_mapl = lan_weight1[0] * sim_mapl11 + lan_weight1[1] * sim_mapl21 + lan_weight1[2] * sim_mapl31

        out_l = torch.matmul(sim_mapl, value_v.permute(0, 1, 3, 2))  # (B, heads, n_l, l_in_channels/heads)
        out_l = out_l.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.l_in_channels)
        out_l = out_l.permute(0, 2, 1)
        out_l = self.W2(out_l)  # (B, l_in_channels, n_l)

        return out_v, out_l


class RefineVisualSim(nn.Module):
    def __init__(self, v_in_channels, l_in_channels, key_channels, kernel, num_heads=1):
        super(RefineVisualSim, self).__init__()

        self.v_in_channels = v_in_channels
        self.l_in_channels = l_in_channels
        # Shrink the per-pixel channels as the spatial window grows, so the
        # unfolded query keeps a comparable size across scales.
        if kernel == 1:
            self.int_channels = key_channels
        elif kernel == 3:
            self.int_channels = key_channels // 2
        elif kernel == 5:
            self.int_channels = key_channels // 4
        self.key_channels = key_channels
        self.num_heads = num_heads
        self.kernel = kernel

        self.f_key = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1),
        )

        self.f_query = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.int_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.int_channels),
        )
        self.f_query2 = nn.Sequential(
            nn.Conv1d(self.int_channels * (self.kernel ** 2), self.key_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.key_channels),
        )

    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels), l: (B, l_in_channels, n_l), l_mask: (B, n_l, 1).
        B, HW = x.size(0), x.size(1)
        n_l = l.size(2)
        l_mask = l_mask.permute(0, 2, 1)  # (B, 1, n_l)

        # Build queries from a kernel x kernel spatial neighbourhood around each pixel.
        # Note: the rearrange assumes a square feature map (H == W).
        x = x.permute(0, 2, 1)
        x1 = self.f_query(x)
        x1 = rearrange(x1, 'b c (h w) -> b c h w', h=int(math.sqrt(x.shape[2])))
        x2 = F.unfold(x1, kernel_size=self.kernel, stride=1, padding=self.kernel // 2)

        query = self.f_query2(x2)
        query = query.permute(0, 2, 1)  # (B, HW, key_channels)
        key = self.f_key(l)
        key = key * l_mask
        query = query.reshape(B, HW, self.num_heads, self.key_channels // self.num_heads).permute(0, 2, 1, 3)
        key = key.reshape(B, self.num_heads, self.key_channels // self.num_heads, n_l)

        l_mask = l_mask.unsqueeze(1)  # (B, 1, 1, n_l)

        # Scaled dot-product similarity; padded language tokens are pushed to
        # a large negative logit before the softmax.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = sim_map + (1e4 * l_mask - 1e4)
        sim_map = F.softmax(sim_map, dim=-1)

        return sim_map


class RefineLanSim(nn.Module):
    def __init__(self, v_in_channels, l_in_channels, key_channels, kernel, num_heads=1):
        super(RefineLanSim, self).__init__()

        self.v_in_channels = v_in_channels
        self.l_in_channels = l_in_channels
        self.key_channels = key_channels
        self.num_heads = num_heads
        self.kernel = kernel
        # Shrink the per-token channels as the token window grows.
        if self.kernel[0] == 1:
            self.int_channels = key_channels
        elif self.kernel[0] == 2:
            self.int_channels = key_channels // 2
        elif self.kernel[0] == 3:
            self.int_channels = key_channels // 3

        self.f_key = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.l_in_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.l_in_channels),
        )

        self.f_query = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.int_channels, kernel_size=1, stride=1),
        )
        self.f_query2 = nn.Sequential(
            nn.Conv1d(self.int_channels * self.kernel[0], self.l_in_channels, kernel_size=1, stride=1),
        )

    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels), l: (B, l_in_channels, n_l), l_mask: (B, n_l, 1).
        B, HW = x.size(0), x.size(1)
        n_l = l.size(2)

        # Build queries from a kernel[0]-token window around each language token;
        # replicate padding keeps the sequence length at n_l.
        l1 = self.f_query(l)
        l1 = l1.unsqueeze(3)
        l1 = F.pad(l1, (0, 0, self.kernel[0] // 2, (self.kernel[0] - 1) // 2), mode='replicate')
        l2 = F.unfold(l1, kernel_size=(self.kernel[0], 1), stride=1)
        query = self.f_query2(l2)
        query = query.permute(0, 2, 1)  # (B, n_l, l_in_channels)

        x = x.permute(0, 2, 1)
        key = self.f_key(x)  # (B, l_in_channels, HW)
        query = query * l_mask
        query = query.reshape(B, n_l, self.num_heads, self.l_in_channels // self.num_heads).permute(0, 2, 1, 3)
        key = key.reshape(B, self.num_heads, self.l_in_channels // self.num_heads, HW)

        l_mask = l_mask.unsqueeze(1)  # (B, 1, n_l, 1)

        # Scaled dot-product similarity between token-window queries and visual keys.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = sim_map + (1e4 * l_mask - 1e4)
        sim_map = F.softmax(sim_map, dim=-1)

        return sim_map
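

# --- Minimal shape-check sketch (added; not part of the original file) ---
# A hedged usage example under these assumptions: random weights (pretrained=None),
# a square input image, 768-dim language features (e.g. from a BERT-style encoder),
# and an l_mask of shape (B, n_l, 1) with 1s for real tokens and 0s for padding.
if __name__ == '__main__':
    model = MultiModalResNet(pretrained=None)
    img = torch.randn(2, 3, 224, 224)    # (B, 3, H, W), H == W required by RefineVisualSim
    lang = torch.randn(2, 768, 20)       # (B, 768, n_l)
    lang_mask = torch.ones(2, 20, 1)     # (B, n_l, 1)
    with torch.no_grad():
        l_out, feats = model(img, lang, lang_mask)
    print(l_out.shape)                   # expected: (2, 768, 20)
    for f in feats:
        print(f.shape)                   # (2, 256, 56, 56) ... (2, 2048, 7, 7)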