import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from .backbone import MultiModalSwinTransformer


def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)

    Note:
        H and W must be multiples of ``window_size``; otherwise the ``view``
        below raises.
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


class MultiModalSwin(MultiModalSwinTransformer):
    """``MultiModalSwinTransformer`` exposed one stage at a time.

    Instead of a single forward pass, this subclass offers separate entry
    points for the stem (``forward_stem``), for the transformer blocks of each
    of the four stages (``forward_stage1`` .. ``forward_stage4``) and for the
    PWAM fusion + downsampling step of each stage (``forward_pwam1`` ..
    ``forward_pwam4``), so a caller can interleave its own logic between them.

    All constructor arguments are forwarded unchanged to the parent class.
    """

    # NOTE: the list defaults below are kept as-is for interface compatibility;
    # this class never mutates them.
    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False,
                 num_heads_fusion=[1, 1, 1, 1],
                 fusion_drop=0.0
                 ):
        super().__init__(pretrain_img_size=pretrain_img_size,
                         patch_size=patch_size,
                         in_chans=in_chans,
                         embed_dim=embed_dim,
                         depths=depths,
                         num_heads=num_heads,
                         window_size=window_size,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         qk_scale=qk_scale,
                         drop_rate=drop_rate,
                         attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate,
                         norm_layer=norm_layer,
                         ape=ape,
                         patch_norm=patch_norm,
                         out_indices=out_indices,
                         frozen_stages=frozen_stages,
                         use_checkpoint=use_checkpoint,
                         num_heads_fusion=num_heads_fusion,
                         fusion_drop=fusion_drop
                         )
        # Kept on the instance because the per-stage helpers below rebuild the
        # shifted-window attention mask themselves.
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.use_checkpoint = use_checkpoint

    def forward_stem(self, x):
        """Patch-embed the input and flatten it into a token sequence.

        Args:
            x: input passed to ``self.patch_embed``
                (presumably a (B, C, H, W) image batch -- confirm with caller).

        Returns:
            tuple ``(x, Wh, Ww)``: the tokens after ``self.pos_drop`` with
            shape (B, Wh*Ww, C), and the spatial size ``Wh`` x ``Ww`` of the
            patch grid (dims 2 and 3 of the patch-embedding output).
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        return x, Wh, Ww

    def _compute_attn_mask(self, H, W, device):
        """Build the attention mask for SW-MSA for an ``H`` x ``W`` token grid.

        The grid is padded up to a multiple of the window size, split into
        3 x 3 regions by the window/shift boundaries, and every region gets a
        distinct id. Within each window, token pairs coming from different
        regions receive -100.0 and pairs from the same region receive 0.0.

        Returns:
            Tensor of shape (nW, window_size**2, window_size**2) on ``device``.
        """
        ws, ss = self.window_size, self.shift_size
        Hp = int(np.ceil(H / ws)) * ws
        Wp = int(np.ceil(W / ws)) * ws
        img_mask = torch.zeros((1, Hp, Wp, 1), device=device)  # 1 Hp Wp 1
        h_slices = (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        w_slices = (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, ws)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, ws * ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _run_stage(self, stage_idx, x, H, W):
        """Run every block of ``self.layers[stage_idx]`` on ``x``.

        Args:
            stage_idx (int): index into ``self.layers``.
            x: token tensor fed to the blocks
                (presumably (B, H*W, dim) -- confirm with caller).
            H, W (int): spatial size of the token grid; written onto each
                block as ``blk.H`` / ``blk.W`` before it is called.

        Returns:
            Output of the last block, shape (B, H*W, dim).
        """
        attn_mask = self._compute_attn_mask(H, W, x.device)
        for blk in self.layers[stage_idx].blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        # output of a Block has shape (B, H*W, dim)
        return x

    def forward_stage1(self, x, H, W):
        """Run the transformer blocks of stage 1 (``self.layers[0]``)."""
        return self._run_stage(0, x, H, W)

    def forward_stage2(self, x, H, W):
        """Run the transformer blocks of stage 2 (``self.layers[1]``)."""
        return self._run_stage(1, x, H, W)

    def forward_stage3(self, x, H, W):
        """Run the transformer blocks of stage 3 (``self.layers[2]``)."""
        return self._run_stage(2, x, H, W)

    def forward_stage4(self, x, H, W):
        """Run the transformer blocks of stage 4 (``self.layers[3]``)."""
        return self._run_stage(3, x, H, W)

    def _fuse_stage(self, stage_idx, x, H, W, l, l_mask):
        """Apply PWAM fusion, the gated residual and optional downsampling.

        Args:
            stage_idx (int): index into ``self.layers``; also selects
                ``self.norm<stage_idx>`` and ``self.num_features[stage_idx]``.
            x: output of the stage's blocks
                (presumably (B, H*W, dim) -- confirm with caller).
            H, W (int): spatial size of the token grid.
            l, l_mask: passed straight to the stage's ``fusion`` module
                (presumably language features and their mask -- confirm).

        Returns:
            6-tuple ``(x_residual, H, W, x_next, H_next, W_next)`` where
            ``x_residual`` is the normalized fusion output reshaped to
            (B, num_features[stage_idx], H, W). When the stage has a
            ``downsample`` module, ``x_next`` is its output and the next size
            is ``((H + 1) // 2, (W + 1) // 2)``; otherwise ``x_next`` is the
            gated sum and the size is unchanged.
        """
        layer = self.layers[stage_idx]

        # PWAM fusion
        x_residual = layer.fusion(x, l, l_mask)
        # apply a gate on the residual
        x = x + (layer.res_gate(x_residual) * x_residual)

        norm_layer = getattr(self, 'norm{}'.format(stage_idx))
        x_residual = norm_layer(x_residual)
        x_residual = x_residual.view(-1, H, W, self.num_features[stage_idx]).permute(0, 3, 1, 2).contiguous()

        if layer.downsample is not None:
            x_down = layer.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x_residual, H, W, x_down, Wh, Ww
        return x_residual, H, W, x, H, W

    def forward_pwam1(self, x, H, W, l, l_mask):
        """PWAM fusion + downsampling for stage 1 (``self.layers[0]``)."""
        return self._fuse_stage(0, x, H, W, l, l_mask)

    def forward_pwam2(self, x, H, W, l, l_mask):
        """PWAM fusion + downsampling for stage 2 (``self.layers[1]``)."""
        return self._fuse_stage(1, x, H, W, l, l_mask)

    def forward_pwam3(self, x, H, W, l, l_mask):
        """PWAM fusion + downsampling for stage 3 (``self.layers[2]``)."""
        return self._fuse_stage(2, x, H, W, l, l_mask)

    def forward_pwam4(self, x, H, W, l, l_mask):
        """PWAM fusion + downsampling for stage 4 (``self.layers[3]``)."""
        return self._fuse_stage(3, x, H, W, l, l_mask)