import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import split_feature, merge_splits


def _use_split_window(attention_type, attn_num_splits):
    """Return True when swin-style split-window attention should be used.

    ``attn_num_splits`` of ``None`` (the default of every ``forward`` below)
    is treated like ``1``, i.e. global full attention, instead of raising a
    ``TypeError`` from evaluating ``None > 1``.
    """
    return attention_type == 'swin' and attn_num_splits is not None and attn_num_splits > 1


def single_head_full_attention(q, k, v):
    """Single-head scaled dot-product attention over all positions.

    Args:
        q, k, v: tensors of shape [B, L, C].

    Returns:
        Tensor of shape [B, L, C] equal to ``softmax(q @ k^T / sqrt(C)) @ v``.
    """
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
                                    shift_size_h, shift_size_w, device=torch.device('cuda')):
    """Build the additive attention mask used by shifted-window (SW-MSA) attention.

    Every pixel of an ``H x W`` map is labelled by one of 3 x 3 regions. The
    regions are defined by the window and shift sizes. The labelled map is
    split into windows. Pixel pairs inside a window with different labels get
    ``-100.0``, which is effectively zero weight after softmax. All other
    pairs get ``0.0``.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    Args:
        input_resolution: ``(H, W)`` of the feature map.
        window_size_h, window_size_w: window height / width.
        shift_size_h, shift_size_w: cyclic shift along height / width.
        device: device on which the mask is created (defaults to CUDA).

    Returns:
        Mask of shape [K*K, Wh*Ww, Wh*Ww], where ``K = W // window_size_w``
        and ``Wh``/``Ww`` are the window height/width.
    """
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1), device=device)  # 1 H W 1

    h_slices = (slice(0, -window_size_h),
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
    w_slices = (slice(0, -window_size_w),
                slice(-window_size_w, -shift_size_w),
                slice(-shift_size_w, None))

    # Label each region with a distinct integer id.
    cnt = 0
    for h_slice in h_slices:
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = cnt
            cnt += 1

    # The split count is derived from the width only, so H and W are assumed
    # to be divided into the same number of windows.
    mask_windows = split_feature(img_mask, num_splits=w // window_size_w, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    # Non-zero difference <=> the two pixels come from different regions.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask


def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    """Single-head attention computed independently inside non-overlapping windows.

    The [B, H*W, C] inputs are viewed as [B, H, W, C] and split into
    ``num_splits x num_splits`` windows. Attention is computed within each
    window, and the windows are merged back.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    Args:
        q, k, v: tensors of shape [B, L, C] with ``L == h * w``.
        num_splits: number of windows along each spatial dimension (K).
        with_shift: if True, do the following (Swin SW-MSA):
            - roll the maps by half a window before splitting;
            - add ``attn_mask`` to the scores;
            - roll the output back afterwards.
        h, w: spatial size of the feature map (required).
        attn_mask: additive mask of shape [K*K, H/K*W/K, H/K*W/K], for
            example from ``generate_shift_window_attn_mask``. Required when
            ``with_shift`` is True.

    Returns:
        Tensor of shape [B, L, C].
    """
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        # The mask holds one entry per window; tile it once per batch element.
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_new, window_size_h, window_size_w, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


class TransformerLayer(nn.Module):
    """Single-head attention layer with an optional feed-forward network.

    ``forward(source, target)`` uses ``source`` as queries and ``target`` as
    keys/values, so passing the same tensor twice gives self-attention. The
    attention output is projected by ``merge`` and normalised by ``norm1``.
    Unless ``no_ffn`` is set, the result is concatenated with ``source`` and
    passed through an MLP plus ``norm2``. The layer returns
    ``source + message``.
    """

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2  # MLP input is cat([source, message])
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        """Attend from ``source`` to ``target`` and return the residual update.

        Args:
            source, target: tensors of shape [B, L, C].
            height, width: spatial size with ``L == height * width``. Needed
                only for split-window attention.
            shifted_window_attn_mask: mask for shifted windows; used only
                when ``with_shift`` is set.
            attn_num_splits: windows per spatial dimension. ``None`` or
                ``<= 1`` selects global full attention.

        Returns:
            Tensor of shape [B, L, C].
        """
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if _use_split_window(self.attention_type, attn_num_splits):
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError('multi-head split-window attention is not implemented; use nhead=1')
            message = single_head_split_window_attention(query, key, value,
                                                         num_splits=attn_num_splits,
                                                         with_shift=self.with_shift,
                                                         h=height,
                                                         w=width,
                                                         attn_mask=shifted_window_attn_mask,
                                                         )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super().__init__()

        self.self_attn = TransformerLayer(d_model=d_model,
                                          nhead=nhead,
                                          attention_type=attention_type,
                                          no_ffn=True,
                                          ffn_dim_expansion=ffn_dim_expansion,
                                          with_shift=with_shift,
                                          )

        self.cross_attn_ffn = TransformerLayer(d_model=d_model,
                                               nhead=nhead,
                                               attention_type=attention_type,
                                               ffn_dim_expansion=ffn_dim_expansion,
                                               with_shift=with_shift,
                                               )

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        """Self-attention on ``source``, then cross-attention to ``target`` plus FFN.

        ``source`` and ``target`` are [B, L, C]. The return value has the
        same shape as ``source``.
        """
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(source, source,
                                height=height,
                                width=width,
                                shifted_window_attn_mask=shifted_window_attn_mask,
                                attn_num_splits=attn_num_splits,
                                )

        # cross attention and ffn
        source = self.cross_attn_ffn(source, target,
                                     height=height,
                                     width=width,
                                     shifted_window_attn_mask=shifted_window_attn_mask,
                                     attn_num_splits=attn_num_splits,
                                     )

        return source


class FeatureTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers that jointly refine two feature maps.

    With ``attention_type == 'swin'``, odd-indexed layers use shifted windows.
    All weight matrices (parameters with more than one dimension) are
    Xavier-uniform initialised.
    """

    def __init__(self,
                 num_layers=6,
                 d_model=128,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 **kwargs,
                 ):
        super().__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model,
                             nhead=nhead,
                             attention_type=attention_type,
                             ffn_dim_expansion=ffn_dim_expansion,
                             with_shift=attention_type == 'swin' and i % 2 == 1,
                             )
            for i in range(num_layers)])

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, feature1,
                attn_num_splits=None,
                **kwargs,
                ):
        """Refine ``feature0`` and ``feature1`` against each other.

        Args:
            feature0, feature1: tensors of shape [B, C, H, W] with
                ``C == d_model``.
            attn_num_splits: windows per spatial dimension for swin attention.
                ``None`` or ``<= 1`` selects global full attention.

        Returns:
            Tuple ``(feature0, feature1)``, each of shape [B, C, H, W].
        """
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if _use_split_window(self.attention_type, attn_num_splits):
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(concat0, concat1,
                            height=h,
                            width=w,
                            shifted_window_attn_mask=shifted_window_attn_mask,
                            attn_num_splits=attn_num_splits,
                            )

            # update feature1: swap the two batch halves so each feature attends to the other
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

        return feature0, feature1


class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super().__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        """Propagate ``flow`` using attention weights computed from ``feature0``.

        Args:
            feature0: tensor of shape [B, C, H, W].
            flow: tensor of shape [B, 2, H, W].
            local_window_attn: if True, each pixel attends only to its
                ``(2R+1) x (2R+1)`` neighbourhood (see
                ``forward_local_window_attn``). Otherwise it attends to all
                pixels.
            local_window_radius: neighbourhood radius R for local attention.

        Returns:
            Tensor of shape [B, 2, H, W].
        """
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        """Flow propagation where each pixel attends to its ``(2R+1)^2`` neighbours.

        Neighbourhoods are gathered with ``F.unfold`` using zero padding. Near
        the border, out-of-image neighbours therefore have an all-zero key
        (score 0) and a zero flow value, and they still receive softmax
        weight.

        Args:
            feature0: tensor of shape [B, C, H, W].
            flow: tensor of shape [B, 2, H, W].
            local_window_radius: neighbourhood radius R (must be > 0).

        Returns:
            Tensor of shape [B, 2, H, W].
        """
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        feature0_flat = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]

        feature0_reshape = self.q_proj(feature0_flat).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = self.k_proj(feature0_flat).permute(0, 2, 1).reshape(b, c, h, w)

        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
                                   padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]

        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=kernel_size,
                               padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]

        flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]

        return out