|
|
|
import torch.nn.functional as F |
|
from .backbone_ppm import MultiModalSwinTransformer |
|
import torch.nn as nn |
|
import numpy as np |
|
import torch |
|
|
|
def window_partition(x, window_size):
    """Cut a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C).
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows * B, window_size, window_size, C),
        with windows ordered batch-first, then row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size

    # Expose the window grid as its own axes: (B, rows, ws, cols, ws, C).
    tiled = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)

    # Move both grid axes in front of both intra-window axes, then make the
    # result contiguous so the final view can merge (B, rows, cols).
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
|
|
|
|
|
class MultiModalSwin(MultiModalSwinTransformer):
    """``MultiModalSwinTransformer`` with each stage exposed as its own call.

    Instead of a single ``forward``, this subclass provides:

    * ``forward_stem``      -- patch embedding (+ optional absolute position
      embedding) and dropout;
    * ``forward_stageN``    -- the window-attention blocks of layer ``N - 1``;
    * ``forward_pwamN``     -- the pyramid/language fusion, gated residual,
      mixer and optional downsample of layer ``N - 1``.

    The numbered methods are thin wrappers around ``_forward_stage`` and
    ``_forward_pwam``; they differ only in the index into ``self.layers``.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False,
                 num_heads_fusion=[1, 1, 1, 1],
                 fusion_drop=0.0
                 ):
        """Forward every argument unchanged to ``MultiModalSwinTransformer``.

        The list defaults are only read and passed through here; they are
        never mutated in this class. See the parent class for the meaning of
        each argument.
        """
        super().__init__(pretrain_img_size=pretrain_img_size,
                         patch_size=patch_size,
                         in_chans=in_chans,
                         embed_dim=embed_dim,
                         depths=depths,
                         num_heads=num_heads,
                         window_size=window_size,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         qk_scale=qk_scale,
                         drop_rate=drop_rate,
                         attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate,
                         norm_layer=norm_layer,
                         ape=ape,
                         patch_norm=patch_norm,
                         out_indices=out_indices,
                         frozen_stages=frozen_stages,
                         use_checkpoint=use_checkpoint,
                         num_heads_fusion=num_heads_fusion,
                         fusion_drop=fusion_drop
                         )
        # Kept locally because the per-stage methods below rebuild the
        # shifted-window attention mask themselves.
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.use_checkpoint = use_checkpoint

    def forward_stem(self, x):
        """Embed the input into patch tokens.

        Args:
            x: input passed to ``self.patch_embed``; the embedding result is
                indexed as a 4-D tensor whose dims 2 and 3 are spatial.

        Returns:
            Tuple ``(tokens, Wh, Ww)`` where ``tokens`` has shape
            (B, Wh * Ww, C) and ``Wh``/``Ww`` are the embedded spatial sizes.
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # Resize the learned absolute position embedding to the current
            # token grid before adding it.
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        return x, Wh, Ww

    # ------------------------------------------------------------------
    # Shared helpers (one implementation for all four stages)
    # ------------------------------------------------------------------

    def _stage_attn_mask(self, H, W, device):
        """Build the attention mask for an ``H`` x ``W`` token grid.

        Returns a tensor of shape (num_windows, ws * ws, ws * ws) holding 0
        where two positions of a window carry the same region label and
        -100 where the labels differ.
        """
        ws = self.window_size
        shift = self.shift_size

        # Round the grid up to a whole number of windows.
        Hp = int(np.ceil(H / ws)) * ws
        Wp = int(np.ceil(W / ws)) * ws
        img_mask = torch.zeros((1, Hp, Wp, 1), device=device)

        # Cut each axis at -ws and -shift and give every resulting region
        # (3 x 3 = 9 of them) its own integer label.
        bands = (slice(0, -ws),
                 slice(-ws, -shift),
                 slice(-shift, None))
        cnt = 0
        for h in bands:
            for w in bands:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, ws)
        mask_windows = mask_windows.view(-1, ws * ws)
        # Pairwise label difference inside each window: non-zero means the
        # two positions belong to different regions.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _forward_stage(self, stage, x, H, W):
        """Run all blocks of ``self.layers[stage]`` on tokens ``x``.

        Args:
            stage (int): index into ``self.layers``.
            x: token tensor; its device is used for the attention mask.
            H, W (int): spatial size of the token grid.

        Returns:
            The output of the last block.
        """
        attn_mask = self._stage_attn_mask(H, W, x.device)

        if self.use_checkpoint:
            # Imported locally: the file-level imports do not bind
            # ``checkpoint``, so the original call raised NameError whenever
            # use_checkpoint was enabled.
            import torch.utils.checkpoint as checkpoint

        for blk in self.layers[stage].blocks:
            # Blocks read the current grid size from these attributes.
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        return x

    def _forward_pwam(self, stage, x, H, W, l, l_mask):
        """Fuse tokens with ``l`` at several pooled scales for one stage.

        Args:
            stage (int): index into ``self.layers`` / ``self.num_features``.
            x: tokens of shape (B, H * W, C).
            H, W (int): spatial size of the token grid.
            l, l_mask: passed unchanged to the layer's fusion modules
                (presumably language features and their mask -- confirm
                against the fusion module).

        Returns:
            ``(x_residual, H, W, x_next, H_next, W_next)`` where
            ``x_residual`` has shape (B, num_features[stage], H, W) and
            ``x_next`` is the gated-residual token tensor, downsampled (with
            ``H_next, W_next = (H + 1) // 2, (W + 1) // 2``) when the layer
            has a ``downsample`` module, otherwise left at (H, W).
        """
        layer = self.layers[stage]
        out = []

        # Tokens -> (B, C, H, W) so the pyramid modules can work spatially.
        x_reshape = x.permute(0, 2, 1).view(x.shape[0], x.shape[2], H, W)
        x_size = x_reshape.size()
        for i, p in enumerate(layer.psizes):
            px = layer.pyramids[i](x_reshape)
            px = px.flatten(2).permute(0, 2, 1)

            px_residual = layer.fusions[i](px, l, l_mask)
            px_residual = px_residual.permute(0, 2, 1).view(x.shape[0], layer.reduction_dim, p, p)

            # Bring the p x p fused map back to the full grid and to token layout.
            out.append(F.interpolate(px_residual, x_size[2:], mode='bilinear', align_corners=True).flatten(2).permute(0, 2, 1))

        # Full-resolution fusion; also feeds the mixer below.
        x_residual = layer.fusion(x, l, l_mask)
        out.append(x_residual)

        # Gated residual update of the tokens, using only the full-resolution fusion.
        x = x + (layer.res_gate(x_residual) * x_residual)

        # Mix all scales along the channel axis, then go to (B, C, H, W).
        x_residual = layer.mixer(torch.cat(out, dim=2))
        x_residual = x_residual.view(-1, H, W, self.num_features[stage]).permute(0, 3, 1, 2).contiguous()

        if layer.downsample is not None:
            x_down = layer.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x_residual, H, W, x_down, Wh, Ww
        return x_residual, H, W, x, H, W

    # ------------------------------------------------------------------
    # Public per-stage entry points (signatures unchanged)
    # ------------------------------------------------------------------

    def forward_stage1(self, x, H, W):
        """Run the blocks of ``self.layers[0]``; see ``_forward_stage``."""
        return self._forward_stage(0, x, H, W)

    def forward_stage2(self, x, H, W):
        """Run the blocks of ``self.layers[1]``; see ``_forward_stage``."""
        return self._forward_stage(1, x, H, W)

    def forward_stage3(self, x, H, W):
        """Run the blocks of ``self.layers[2]``; see ``_forward_stage``."""
        return self._forward_stage(2, x, H, W)

    def forward_stage4(self, x, H, W):
        """Run the blocks of ``self.layers[3]``; see ``_forward_stage``."""
        return self._forward_stage(3, x, H, W)

    def forward_pwam1(self, x, H, W, l, l_mask):
        """Fusion step of ``self.layers[0]``; see ``_forward_pwam``."""
        return self._forward_pwam(0, x, H, W, l, l_mask)

    def forward_pwam2(self, x, H, W, l, l_mask):
        """Fusion step of ``self.layers[1]``; see ``_forward_pwam``."""
        return self._forward_pwam(1, x, H, W, l, l_mask)

    def forward_pwam3(self, x, H, W, l, l_mask):
        """Fusion step of ``self.layers[2]``; see ``_forward_pwam``."""
        return self._forward_pwam(2, x, H, W, l, l_mask)

    def forward_pwam4(self, x, H, W, l, l_mask):
        """Fusion step of ``self.layers[3]``; see ``_forward_pwam``."""
        return self._forward_pwam(3, x, H, W, l, l_mask)
|
|