from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from ..builder import BACKBONES
from mmseg.utils import get_root_logger
from mmcv.runner import load_checkpoint


class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features`` when falsy.
        act_layer: Callable returning the activation module.
        drop: Dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head attention whose keys/values may be spatially reduced.

    When ``sr_ratio > 1`` the token sequence is reshaped to a ``(B, C, H, W)``
    map, down-sampled by a strided convolution and layer-normalised before the
    key/value projection, so keys and values are computed on fewer tokens than
    queries.

    Args:
        dim: Channel dimension of the tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        qk_scale: Overrides the default ``head_dim ** -0.5`` scaling when
            truthy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        sr_ratio: Kernel size and stride of the reduction convolution; values
            ``<= 1`` disable the reduction.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        """Attend over ``x`` of shape ``(B, N, C)``; ``N`` must equal ``H * W``
        whenever the spatial reduction is active (the tokens are reshaped to
        ``(B, C, H, W)``)."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # Down-sample the token grid before projecting keys/values.
            kv_src = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_src = self.sr(kv_src).reshape(B, C, -1).permute(0, 2, 1)
            kv_src = self.norm(kv_src)
        else:
            kv_src = x
        # -> (2, B, heads, tokens, head_dim)
        kv = self.kv(kv_src).reshape(B, -1, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual
    connection and (optional) stochastic depth.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, sr_ratio: Forwarded to ``Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Stochastic-depth rate; ``nn.Identity`` is used when it is
            not greater than zero.
        act_layer: Activation factory for the MLP.
        norm_layer: Factory for the two normalisation layers.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    A convolution with ``kernel_size == stride == patch_size`` projects each
    patch to ``embed_dim`` channels; the result is flattened to a token
    sequence and layer-normalised.

    Attributes:
        H, W: Patch-grid size for the configured ``img_size``.
        num_patches: ``H * W``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return ``(tokens, (H, W))`` where ``H``/``W`` is the patch grid of
        the actual input, which may differ from the configured one."""
        _, _, height, width = x.shape

        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        grid = (height // self.patch_size[0], width // self.patch_size[1])
        return x, grid


class PyramidVisionTransformer(nn.Module):
    """Four-stage transformer backbone returning one feature map per stage.

    Each stage consists of a patch embedding, a learned position embedding,
    dropout and a list of ``Block`` modules. Stage 1 uses ``patch_size``;
    stages 2-4 use a patch size of 2 on the previous stage's output.

    Args:
        img_size: Nominal input size used to size the position embeddings
            (integer; it is floor-divided by 4, 8 and 16 for later stages).
        patch_size: Patch size of the first stage.
        in_chans: Number of input image channels.
        num_classes: Stored on the instance; no classification head is built
            here.
        embed_dims, num_heads, mlp_ratios, depths, sr_ratios: Per-stage
            settings, indexed 0..3. The list defaults are only read, never
            mutated.
        qkv_bias, qk_scale: Forwarded to every ``Block``.
        drop_rate: Dropout probability (position dropout and blocks).
        attn_drop_rate: Dropout probability on attention weights.
        drop_path_rate: Maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value over all blocks.
        norm_layer: Normalisation factory used inside the blocks.
        F4: When true, ``forward`` returns only the last stage's feature map
            (still wrapped in a list).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], F4=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.F4 = F4

        # patch_embed
        self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                       in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2,
                                       in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2,
                                       in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2,
                                       in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # pos_embed
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
        self.pos_drop1 = nn.Dropout(p=drop_rate)
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
        self.pos_drop2 = nn.Dropout(p=drop_rate)
        self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
        self.pos_drop3 = nn.Dropout(p=drop_rate)
        # Stage 4 keeps one extra leading slot; forward_features skips it
        # with pos_embed4[:, 1:].
        self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
        self.pos_drop4 = nn.Dropout(p=drop_rate)

        # transformer encoder
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        def build_stage(stage, offset):
            # One ModuleList of Blocks for `stage`; `offset` is the index of
            # the stage's first block inside the global `dpr` schedule.
            return nn.ModuleList([
                Block(dim=embed_dims[stage], num_heads=num_heads[stage],
                      mlp_ratio=mlp_ratios[stage], qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[offset + i],
                      norm_layer=norm_layer, sr_ratio=sr_ratios[stage])
                for i in range(depths[stage])])

        cur = 0
        self.block1 = build_stage(0, cur)
        cur += depths[0]
        self.block2 = build_stage(1, cur)
        cur += depths[1]
        self.block3 = build_stage(2, cur)
        cur += depths[2]
        self.block4 = build_stage(3, cur)

        # init weights
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)
        trunc_normal_(self.pos_embed3, std=.02)
        trunc_normal_(self.pos_embed4, std=.02)
        self.apply(self._init_weights)

    def init_weights(self, pretrained=None):
        """Load a checkpoint (non-strict, mapped to CPU) when ``pretrained``
        is a string; any other value leaves the weights untouched."""
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        """Re-assign the linearly increasing stochastic-depth schedule.

        Blocks created with a zero rate hold ``nn.Identity`` instead of
        ``DropPath``; assigning ``drop_prob`` on those would have no effect,
        so a ``DropPath`` is installed for them when a non-zero rate is
        requested.
        """
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        stages = (self.block1, self.block2, self.block3, self.block4)
        all_blocks = [blk for stage in stages for blk in stage]
        for blk, rate in zip(all_blocks, dpr):
            if isinstance(blk.drop_path, DropPath):
                blk.drop_path.drop_prob = rate
            elif rate > 0.:
                replacement = DropPath(rate)
                # A fresh module starts in training mode; follow the block so
                # an eval-mode model does not start dropping paths.
                replacement.train(blk.training)
                blk.drop_path = replacement

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Return ``pos_embed`` resized to an ``H x W`` token grid.

        The table is returned unchanged only when the runtime grid equals the
        grid ``patch_embed`` was configured for. The comparison is made
        against the stage's own ``patch_embed`` (not stage 1) and on both
        dimensions (not the product), so neither a coincidental token-count
        match in a later stage nor a different aspect ratio with the same
        token count can skip the resize.
        """
        if (H, W) == (patch_embed.H, patch_embed.W):
            return pos_embed
        grid = pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(H, W), mode="bilinear")
        return resized.reshape(1, -1, H * W).permute(0, 2, 1)

    def _forward_stage(self, x, patch_embed, pos_embed, pos_drop, blocks):
        """Run one stage and return its output as a ``(B, C, H, W)`` map."""
        B = x.shape[0]
        x, (H, W) = patch_embed(x)
        x = x + self._get_pos_embed(pos_embed, patch_embed, H, W)
        x = pos_drop(x)
        for blk in blocks:
            x = blk(x, H, W)
        return x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

    def forward_features(self, x):
        """Return the list of the four stage outputs, in stage order."""
        stages = (
            (self.patch_embed1, self.pos_embed1, self.pos_drop1, self.block1),
            (self.patch_embed2, self.pos_embed2, self.pos_drop2, self.block2),
            (self.patch_embed3, self.pos_embed3, self.pos_drop3, self.block3),
            # Skip the extra leading slot of the stage-4 table.
            (self.patch_embed4, self.pos_embed4[:, 1:], self.pos_drop4, self.block4),
        )
        outs = []
        for patch_embed, pos_embed, pos_drop, blocks in stages:
            x = self._forward_stage(x, patch_embed, pos_embed, pos_drop, blocks)
            outs.append(x)
        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.F4:
            # Keep only the last stage, still as a one-element list.
            x = x[3:4]

        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            # The channel count is fixed to 3 here.
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict


@BACKBONES.register_module()
class pvt_tiny(PyramidVisionTransformer):
    """Preset with depths ``[2, 2, 2, 2]``.

    ``**kwargs`` is accepted but intentionally not forwarded, so extra keys
    in a config are ignored rather than rejected.
    """

    def __init__(self, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.2)


@BACKBONES.register_module()
class pvt_small(PyramidVisionTransformer):
    """Preset with depths ``[3, 4, 6, 3]``.

    ``**kwargs`` is accepted but intentionally not forwarded, so extra keys
    in a config are ignored rather than rejected.
    """

    def __init__(self, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.2)


@BACKBONES.register_module()
class pvt_small_f4(PyramidVisionTransformer):
    """Same settings as ``pvt_small`` but with ``F4=True``, so only the last
    stage's feature map is returned.

    ``**kwargs`` is accepted but intentionally not forwarded, so extra keys
    in a config are ignored rather than rejected.
    """

    def __init__(self, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.2, F4=True)


@BACKBONES.register_module()
class pvt_large(PyramidVisionTransformer):
    """Preset with depths ``[3, 8, 27, 3]``.

    ``**kwargs`` is accepted but intentionally not forwarded, so extra keys
    in a config are ignored rather than rejected.
    """

    def __init__(self, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.2)