import cv2 import math import torch import numpy as np import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import os import gdown class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ with torch.cuda.amp.autocast(True): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) with torch.cuda.amp.autocast(False): q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) with torch.cuda.amp.autocast(True): x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) with torch.cuda.amp.autocast(True): x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, img_size=112, patch_size=2, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.feature = nn.Sequential( nn.Linear(in_features=self.num_features, out_features=self.num_features, bias=False), nn.BatchNorm1d(num_features=self.num_features, eps=2e-5), nn.Linear(in_features=self.num_features, out_features=num_classes, bias=False), nn.BatchNorm1d(num_features=num_classes, eps=2e-5) ) self.feature_resolution = (patches_resolution[0] // (2 ** (self.num_layers-1)), patches_resolution[1] // (2 ** (self.num_layers-1))) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def forward_features(self, x): patches_resolution = self.patch_embed.patches_resolution x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) local_features = [] i = 0 for layer in self.layers: i += 1 x = layer(x) if not i == self.num_layers: H = patches_resolution[0] // (2 ** i) W = patches_resolution[1] // (2 ** i) B, L, C = x.shape temp = x.transpose(1, 2).reshape(B, C, H, W) win_h = H // self.feature_resolution[0] win_w = W // self.feature_resolution[1] if not (win_h == 1 and win_w == 1): temp = F.avg_pool2d(temp, kernel_size=(win_h, win_w)) local_features.append(temp) local_features = torch.cat(local_features, dim=1) # B, C, H, W global_features = x B, L, C = global_features.shape global_features = global_features.transpose(1, 2).reshape(B, C, self.feature_resolution[0], self.feature_resolution[1]) # B, C, H, W x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return local_features, global_features, x def forward(self, x): local_features, global_features, x = self.forward_features(x) x = self.feature(x) return local_features, global_features, x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops class BasicConv(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): super(BasicConv, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None self.relu = nn.ReLU() if relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type=='avg': avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( avg_pool ) elif pool_type=='max': max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( max_pool ) elif pool_type=='lp': lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( lp_pool ) elif pool_type=='lse': # LSE pool only lse_pool = logsumexp_2d(x) channel_att_raw = self.mlp( lse_pool ) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale def logsumexp_2d(tensor): tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() return outputs class ChannelPool(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPool() self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = F.sigmoid(x_out) # broadcasting return x * scale class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): super(CBAM, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) self.no_spatial=no_spatial if not no_spatial: self.SpatialGate = SpatialGate() def forward(self, x): x_out = self.ChannelGate(x) if not self.no_spatial: x_out = self.SpatialGate(x_out) return x_out class ConvLayer(torch.nn.Module): def __init__(self, in_chans=768, out_chans=512, conv_mode="normal", kernel_size=3): super().__init__() self.conv_mode = conv_mode if conv_mode == "normal": self.conv = nn.Conv2d(in_chans, out_chans, kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False) elif conv_mode == "split": self.convs = nn.ModuleList() for j in range(len(in_chans)): conv = nn.Conv2d(in_chans[j], out_chans[j], kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False) self.convs.append(conv) self.cut = [0 for i in range(len(in_chans)+1)] self.cut[0] = 0 for i in range(1, len(in_chans)+1): self.cut[i] = self.cut[i - 1] + in_chans[i-1] def forward(self, x): if self.conv_mode == "normal": x = self.conv(x) elif self.conv_mode == "split": outputs = [] for j in range(len(self.cut)-1): input_map = x[:, self.cut[j]:self.cut[j+1]] #print(input_map.shape) output_map = self.convs[j](input_map) outputs.append(output_map) #print(output_map.shape) x = torch.cat(outputs, dim=1) return x class LANet(torch.nn.Module): def __init__(self, in_chans=512, reduction_ratio=2.0): super().__init__() self.in_chans = in_chans self.mid_chans = int(self.in_chans/reduction_ratio) self.conv1 = nn.Conv2d(self.in_chans, self.mid_chans, kernel_size=(1, 1), stride=(1, 1), bias=False) self.conv2 = nn.Conv2d(self.mid_chans, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) def forward(self, x): x = F.relu(self.conv1(x)) x = torch.sigmoid(self.conv2(x)) return x def MAD(x, p=0.6): B, C, W, H = x.shape mask1 = torch.cat([torch.randperm(C).unsqueeze(dim=0) for j in range(B)], dim=0).cuda() mask2 = torch.rand([B, C]).cuda() ones = torch.ones([B, C], dtype=torch.float).cuda() zeros = torch.zeros([B, C], dtype=torch.float).cuda() mask = torch.where(mask1 == 0, zeros, ones) mask = torch.where(mask2 < p, mask, ones) x = x.permute(2, 3, 0, 1) x = x.mul(mask) x = x.permute(2, 3, 0, 1) return x class LANets(torch.nn.Module): def __init__(self, branch_num=2, feature_dim=512, la_reduction_ratio=2.0, MAD=MAD): super().__init__() self.LANets = nn.ModuleList() for i in range(branch_num): self.LANets.append(LANet(in_chans=feature_dim, reduction_ratio=la_reduction_ratio)) self.MAD = MAD self.branch_num = branch_num def forward(self, x): B, C, W, H = x.shape outputs = [] for lanet in self.LANets: output = lanet(x) outputs.append(output) LANets_output = torch.cat(outputs, dim=1) if self.MAD and self.branch_num != 1: LANets_output = self.MAD(LANets_output) mask = torch.max(LANets_output, dim=1).values.reshape(B, 1, W, H) x = x.mul(mask) return x class FeatureAttentionNet(torch.nn.Module): def __init__(self, in_chans=768, feature_dim=512, kernel_size=3, conv_shared=False, conv_mode="normal", channel_attention=None, spatial_attention=None, pooling="max", la_branch_num=2): super().__init__() self.conv_shared = conv_shared self.channel_attention = channel_attention self.spatial_attention = spatial_attention if not self.conv_shared: if conv_mode == "normal": self.conv = ConvLayer(in_chans=in_chans, out_chans=feature_dim, conv_mode="normal", kernel_size=kernel_size) elif conv_mode == "split" and in_chans == 2112: self.conv = ConvLayer(in_chans=[192, 384, 768, 768], out_chans=[47, 93, 186, 186], conv_mode="split", kernel_size=kernel_size) if self.channel_attention == "CBAM": self.channel_attention = ChannelGate(gate_channels=feature_dim) if self.spatial_attention == "CBAM": self.spatial_attention = SpatialGate() elif self.spatial_attention == "LANet": self.spatial_attention = LANets(branch_num=la_branch_num, feature_dim=feature_dim) if pooling == "max": self.pool = nn.AdaptiveMaxPool2d((1, 1)) elif pooling == "avg": self.pool = nn.AdaptiveAvgPool2d((1, 1)) self.act = nn.ReLU(inplace=True) self.norm = nn.BatchNorm1d(num_features=feature_dim, eps=2e-5) def forward(self, x): if not self.conv_shared: x = self.conv(x) if self.channel_attention: x = self.channel_attention(x) if self.spatial_attention: x = self.spatial_attention(x) x = self.act(x) B, C, _, __ = x.shape x = self.pool(x).reshape(B, C) x = self.norm(x) return x class FeatureAttentionModule(torch.nn.Module): def __init__(self, branch_num=11, in_chans=2112, feature_dim=512, conv_shared=False, conv_mode="split", kernel_size=3, channel_attention="CBAM", spatial_attention=None, la_num_list=[2 for j in range(11)], pooling="max"): super().__init__() self.conv_shared = conv_shared if self.conv_shared: if conv_mode == "normal": self.conv = ConvLayer(in_chans=in_chans, out_chans=feature_dim, conv_mode="normal", kernel_size=kernel_size) elif conv_mode == "split" and in_chans == 2112: self.conv = ConvLayer(in_chans=[192, 384, 768, 768], out_chans=[47, 93, 186, 186], conv_mode="split", kernel_size=kernel_size) self.nets = nn.ModuleList() for i in range(branch_num): net = FeatureAttentionNet(in_chans=in_chans, feature_dim=feature_dim, conv_shared=conv_shared, conv_mode=conv_mode, kernel_size=kernel_size, channel_attention=channel_attention, spatial_attention=spatial_attention, la_branch_num=la_num_list[i], pooling=pooling) self.nets.append(net) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x): if self.conv_shared: x = self.conv(x) outputs = [] for net in self.nets: output = net(x).unsqueeze(dim=0) outputs.append(output) outputs = torch.cat(outputs, dim=0) return outputs class TaskSpecificSubnet(torch.nn.Module): def __init__(self, feature_dim=512, drop_rate=0.5): super().__init__() self.feature = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.ReLU(True), nn.Dropout(drop_rate), nn.Linear(feature_dim, feature_dim), nn.ReLU(True), nn.Dropout(drop_rate),) def forward(self, x): return self.feature(x) class TaskSpecificSubnets(torch.nn.Module): def __init__(self, branch_num=11): super().__init__() self.branch_num = branch_num self.nets = nn.ModuleList() for i in range(self.branch_num): net = TaskSpecificSubnet(drop_rate=0.5) self.nets.append(net) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x): outputs = [] for i in range(self.branch_num): net = self.nets[i] output = net(x[i]).unsqueeze(dim=0) outputs.append(output) outputs = torch.cat(outputs, dim=0) return outputs class OutputModule(torch.nn.Module): def __init__(self, feature_dim=512, output_type="Dict"): super().__init__() self.output_sizes = [[2], [1, 2], [7, 2], [2 for j in range(6)], [2 for j in range(10)], [2 for j in range(5)], [2, 2], [2 for j in range(4)], [2 for j in range(6)], [2, 2], [2, 2]] self.output_fcs = nn.ModuleList() for i in range(0, len(self.output_sizes)): for j in range(len(self.output_sizes[i])): output_fc = nn.Linear(feature_dim, self.output_sizes[i][j]) self.output_fcs.append(output_fc) self.task_names = [ 'Age', 'Attractive', 'Blurry', 'Chubby', 'Heavy Makeup', 'Gender', 'Oval Face', 'Pale Skin', 'Smiling', 'Young', 'Bald', 'Bangs', 'Black Hair', 'Blond Hair', 'Brown Hair', 'Gray Hair', 'Receding Hairline', 'Straight Hair', 'Wavy Hair', 'Wearing Hat', 'Arched Eyebrows', 'Bags Under Eyes', 'Bushy Eyebrows', 'Eyeglasses', 'Narrow Eyes', 'Big Nose', 'Pointy Nose', 'High Cheekbones', 'Rosy Cheeks', 'Wearing Earrings', 'Sideburns', r"Five O'Clock Shadow", 'Big Lips', 'Mouth Slightly Open', 'Mustache', 'Wearing Lipstick', 'No Beard', 'Double Chin', 'Goatee', 'Wearing Necklace', 'Wearing Necktie', 'Expression', 'Recognition'] # Total:43 self.output_type = output_type self.apply(self._init_weights) def set_output_type(self, output_type): self.output_type = output_type def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x, embedding): outputs = [] k = 0 for i in range(0, len(self.output_sizes)): for j in range(len(self.output_sizes[i])): output_fc = self.output_fcs[k] output = output_fc(x[i]) outputs.append(output) k += 1 [gender, age, young, expression, smiling, attractive, blurry, chubby, heavy_makeup, oval_face, pale_skin, bald, bangs, black_hair, blond_hair, brown_hair, gray_hair, receding_hairline, straight_hair, wavy_hair, wearing_hat, arched_eyebrows, bags_under_eyes, bushy_eyebrows, eyeglasses, narrow_eyes, big_nose, pointy_nose, high_cheekbones, rosy_cheeks, wearing_earrings, sideburns, five_o_clock_shadow, big_lips, mouth_slightly_open, mustache, wearing_lipstick, no_beard, double_chin, goatee, wearing_necklace, wearing_necktie] = outputs outputs = [age, attractive, blurry, chubby, heavy_makeup, gender, oval_face, pale_skin, smiling, young, bald, bangs, black_hair, blond_hair, brown_hair, gray_hair, receding_hairline, straight_hair, wavy_hair, wearing_hat, arched_eyebrows, bags_under_eyes, bushy_eyebrows, eyeglasses, narrow_eyes, big_nose, pointy_nose, high_cheekbones, rosy_cheeks, wearing_earrings, sideburns, five_o_clock_shadow, big_lips, mouth_slightly_open, mustache, wearing_lipstick, no_beard, double_chin, goatee, wearing_necklace, wearing_necktie, expression] # Total:42 outputs.append(embedding) result = dict() for j in range(43): result[self.task_names[j]] = outputs[j] if self.output_type == "Dict": return result elif self.output_type == "List": return outputs elif self.output_type == "Attribute": return outputs[1: 41] else: return result[self.output_type] class ModelBox(torch.nn.Module): def __init__(self, backbone=None, fam=None, tss=None, om=None, feature="global", output_type="Dict"): super().__init__() self.backbone = backbone self.fam = fam self.tss = tss self.om = om self.output_type = output_type if self.om: self.om.set_output_type(self.output_type) self.feature = feature def set_output_type(self, output_type): self.output_type = output_type if self.om: self.om.set_output_type(self.output_type) def forward(self, x): local_features, global_features, embedding = self.backbone(x) if self.feature == "all": x = torch.cat([local_features, global_features], dim=1) elif self.feature == "global": x = global_features elif self.feature == "local": x = local_features x = self.fam(x) x = self.tss(x) x = self.om(x, embedding) return x def build_model(cfg): backbone = SwinTransformer(num_classes=cfg.embedding_size) fam = FeatureAttentionModule( in_chans=cfg.fam_in_chans, kernel_size=cfg.fam_kernel_size, conv_shared=cfg.fam_conv_shared, conv_mode=cfg.fam_conv_mode, channel_attention=cfg.fam_channel_attention, spatial_attention=cfg.fam_spatial_attention, pooling=cfg.fam_pooling, la_num_list=cfg.fam_la_num_list) tss = TaskSpecificSubnets() om = OutputModule() model = ModelBox(backbone=backbone, fam=fam, tss=tss, om=om, feature=cfg.fam_feature) return model class SwinFaceCfg: network = "swin_t" fam_kernel_size=3 fam_in_chans=2112 fam_conv_shared=False fam_conv_mode="split" fam_channel_attention="CBAM" fam_spatial_attention=None fam_pooling="max" fam_la_num_list=[2 for j in range(11)] fam_feature="all" fam = "3x3_2112_F_s_C_N_max" embedding_size = 512 @torch.no_grad() def load_model(): cfg = SwinFaceCfg() weight = os.getcwd() + "/weights.pt" if not os.path.isfile(weight): gdown.download("https://drive.google.com/uc?export=download&id=1fi4IuuFV8NjnWm-CufdrhMKrkjxhSmjx", weight) model = build_model(cfg) dict_checkpoint = torch.load(weight, map_location=torch.device('cpu')) model.backbone.load_state_dict(dict_checkpoint["state_dict_backbone"]) model.fam.load_state_dict(dict_checkpoint["state_dict_fam"]) model.tss.load_state_dict(dict_checkpoint["state_dict_tss"]) model.om.load_state_dict(dict_checkpoint["state_dict_om"]) model.eval() return model def get_embeddings(model, images): embeddings = [] for img in images: img = cv2.resize(np.array(img), (112, 112)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = np.transpose(img, (2, 0, 1)) img = torch.from_numpy(img).unsqueeze(0).float() img.div_(255).sub_(0.5).div_(0.5) with torch.inference_mode(): output = model(img) embeddings.append(output["Recognition"][0].numpy()) return embeddings