from functools import partial from itertools import repeat #from torch._six import container_abcs import collections.abc as container_abcs import logging import os from collections import OrderedDict import numpy as np import scipy import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange from timm.models.layers import DropPath, trunc_normal_ #from .registry import register_model #from config import config from torchinfo import summary import json _model_entrypoints = {} def register_model(fn): module_name_split = fn.__module__.split('.') model_name = module_name_split[-1] _model_entrypoints[model_name] = fn return fn def model_entrypoints(model_name): return _model_entrypoints[model_name] def is_model(model_name): return model_name in _model_entrypoints # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, container_abcs.Iterable): return x return tuple(repeat(x, n)) return parse to_1tuple = _ntuple(1) to_2tuple = _ntuple(2) to_3tuple = _ntuple(3) to_4tuple = _ntuple(4) to_ntuple = _ntuple class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__(self, dim_in, dim_out, num_heads, qkv_bias=False, attn_drop=0., proj_drop=0., method='dw_bn', kernel_size=3, stride_kv=1, stride_q=1, padding_kv=1, padding_q=1, with_cls_token=True, **kwargs ): super().__init__() self.stride_kv = stride_kv self.stride_q = stride_q self.dim = dim_out self.num_heads = num_heads # head_dim = self.qkv_dim // num_heads self.scale = dim_out ** -0.5 self.with_cls_token = with_cls_token self.conv_proj_q = self._build_projection( dim_in, dim_out, kernel_size, padding_q, stride_q, 'linear' if method == 'avg' else method ) self.conv_proj_k = self._build_projection( dim_in, dim_out, kernel_size, padding_kv, stride_kv, method ) self.conv_proj_v = self._build_projection( dim_in, dim_out, kernel_size, padding_kv, stride_kv, method ) self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias) self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias) self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim_out, dim_out) self.proj_drop = nn.Dropout(proj_drop) def _build_projection(self, dim_in, dim_out, kernel_size, padding, stride, method): if method == 'dw_bn': proj = nn.Sequential(OrderedDict([ ('conv', nn.Conv2d( dim_in, dim_in, kernel_size=kernel_size, padding=padding, stride=stride, bias=False, groups=dim_in )), ('bn', nn.BatchNorm2d(dim_in)), ('rearrage', Rearrange('b c h w -> b (h w) c')), ])) elif method == 'avg': proj = nn.Sequential(OrderedDict([ ('avg', nn.AvgPool2d( kernel_size=kernel_size, padding=padding, stride=stride, ceil_mode=True )), ('rearrage', Rearrange('b c h w -> b (h w) c')), ])) elif method == 'linear': proj = None else: raise ValueError('Unknown method ({})'.format(method)) return proj def forward_conv(self, x, h, w): if self.with_cls_token: cls_token, x = torch.split(x, [1, h*w], 1) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) if self.conv_proj_q is not None: q = self.conv_proj_q(x) else: q = rearrange(x, 'b c h w -> b (h w) c') if self.conv_proj_k is not None: k = self.conv_proj_k(x) else: k = rearrange(x, 'b c h w -> b (h w) c') if self.conv_proj_v is not None: v = self.conv_proj_v(x) else: v = rearrange(x, 'b c h w -> b (h w) c') if self.with_cls_token: q = torch.cat((cls_token, q), dim=1) k = torch.cat((cls_token, k), dim=1) v = torch.cat((cls_token, v), dim=1) return q, k, v def forward(self, x, h, w): if ( self.conv_proj_q is not None or self.conv_proj_k is not None or self.conv_proj_v is not None ): q, k, v = self.forward_conv(x, h, w) q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads) k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads) v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads) attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale attn = F.softmax(attn_score, dim=-1) attn = self.attn_drop(attn) x = torch.einsum('bhlt,bhtv->bhlv', [attn, v]) x = rearrange(x, 'b h t d -> b t (h d)') x = self.proj(x) x = self.proj_drop(x) return x @staticmethod def compute_macs(module, input, output): # T: num_token # S: num_token input = input[0] flops = 0 _, T, C = input.shape H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T)) H_Q = H / module.stride_q W_Q = H / module.stride_q T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q H_KV = H / module.stride_kv W_KV = W / module.stride_kv T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV # C = module.dim # S = T # Scaled-dot-product macs # [B x T x C] x [B x C x T] --> [B x T x S] # multiplication-addition is counted as 1 because operations can be fused flops += T_Q * T_KV * module.dim # [B x T x S] x [B x S x C] --> [B x T x C] flops += T_Q * module.dim * T_KV if ( hasattr(module, 'conv_proj_q') and hasattr(module.conv_proj_q, 'conv') ): params = sum( [ p.numel() for p in module.conv_proj_q.conv.parameters() ] ) flops += params * H_Q * W_Q if ( hasattr(module, 'conv_proj_k') and hasattr(module.conv_proj_k, 'conv') ): params = sum( [ p.numel() for p in module.conv_proj_k.conv.parameters() ] ) flops += params * H_KV * W_KV if ( hasattr(module, 'conv_proj_v') and hasattr(module.conv_proj_v, 'conv') ): params = sum( [ p.numel() for p in module.conv_proj_v.conv.parameters() ] ) flops += params * H_KV * W_KV params = sum([p.numel() for p in module.proj_q.parameters()]) flops += params * T_Q params = sum([p.numel() for p in module.proj_k.parameters()]) flops += params * T_KV params = sum([p.numel() for p in module.proj_v.parameters()]) flops += params * T_KV params = sum([p.numel() for p in module.proj.parameters()]) flops += params * T module.__flops__ += flops class Block(nn.Module): def __init__(self, dim_in, dim_out, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, **kwargs): super().__init__() self.with_cls_token = kwargs['with_cls_token'] self.norm1 = norm_layer(dim_in) self.attn = Attention( dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop, **kwargs ) self.drop_path = DropPath(drop_path) \ if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim_out) dim_mlp_hidden = int(dim_out * mlp_ratio) self.mlp = Mlp( in_features=dim_out, hidden_features=dim_mlp_hidden, act_layer=act_layer, drop=drop ) def forward(self, x, h, w): res = x x = self.norm1(x) attn = self.attn(x, h, w) x = res + self.drop_path(attn) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class ConvEmbed(nn.Module): """ Image to Conv Embedding """ def __init__(self, patch_size=7, in_chans=1, #1 for spectrogram, 3 for rgb image embed_dim=64, stride=4, padding=2, norm_layer=None): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding ) self.norm = norm_layer(embed_dim) if norm_layer else None def forward(self, x): x = self.proj(x) B, C, H, W = x.shape x = rearrange(x, 'b c h w -> b (h w) c') if self.norm: x = self.norm(x) x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) return x class VisionTransformer(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, patch_size=16, patch_stride=16, patch_padding=0, in_chans=1, #1for spectrogram, 3 for RGB embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init='trunc_norm', **kwargs): super().__init__() self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.rearrage = None self.patch_embed = ConvEmbed( # img_size=img_size, patch_size=patch_size, in_chans=in_chans, stride=patch_stride, padding=patch_padding, embed_dim=embed_dim, norm_layer=norm_layer ) with_cls_token = kwargs['with_cls_token'] if with_cls_token: self.cls_token = nn.Parameter( torch.zeros(1, 1, embed_dim) ) else: self.cls_token = None self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule blocks = [] for j in range(depth): blocks.append( Block( dim_in=embed_dim, dim_out=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], act_layer=act_layer, norm_layer=norm_layer, **kwargs ) ) self.blocks = nn.ModuleList(blocks) if self.cls_token is not None: trunc_normal_(self.cls_token, std=.02) if init == 'xavier': self.apply(self._init_weights_xavier) else: self.apply(self._init_weights_trunc_normal) def _init_weights_trunc_normal(self, m): if isinstance(m, nn.Linear): logging.info('=> init weight of Linear from trunc norm') trunc_normal_(m.weight, std=0.02) if m.bias is not None: logging.info('=> init bias of Linear to zeros') nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def _init_weights_xavier(self, m): if isinstance(m, nn.Linear): logging.info('=> init weight of Linear from xavier uniform') nn.init.xavier_uniform_(m.weight) if m.bias is not None: logging.info('=> init bias of Linear to zeros') nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x): x = self.patch_embed(x) B, C, H, W = x.size() x = rearrange(x, 'b c h w -> b (h w) c') cls_tokens = None if self.cls_token is not None: # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) x = self.pos_drop(x) for i, blk in enumerate(self.blocks): x = blk(x, H, W) if self.cls_token is not None: cls_tokens, x = torch.split(x, [1, H*W], 1) x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) return x, cls_tokens class ConvolutionalVisionTransformer(nn.Module): def __init__(self, in_chans=1, #3 for RGB, 1 for Spectrogram num_classes=1000, act_layer=nn.GELU, norm_layer=nn.LayerNorm, init='trunc_norm', spec=None): super().__init__() self.num_classes = num_classes self.num_stages = spec['NUM_STAGES'] for i in range(self.num_stages): kwargs = { 'patch_size': spec['PATCH_SIZE'][i], 'patch_stride': spec['PATCH_STRIDE'][i], 'patch_padding': spec['PATCH_PADDING'][i], 'embed_dim': spec['DIM_EMBED'][i], 'depth': spec['DEPTH'][i], 'num_heads': spec['NUM_HEADS'][i], 'mlp_ratio': spec['MLP_RATIO'][i], 'qkv_bias': spec['QKV_BIAS'][i], 'drop_rate': spec['DROP_RATE'][i], 'attn_drop_rate': spec['ATTN_DROP_RATE'][i], 'drop_path_rate': spec['DROP_PATH_RATE'][i], 'with_cls_token': spec['CLS_TOKEN'][i], 'method': spec['QKV_PROJ_METHOD'][i], 'kernel_size': spec['KERNEL_QKV'][i], 'padding_q': spec['PADDING_Q'][i], 'padding_kv': spec['PADDING_KV'][i], 'stride_kv': spec['STRIDE_KV'][i], 'stride_q': spec['STRIDE_Q'][i], } stage = VisionTransformer( in_chans=in_chans, init=init, act_layer=act_layer, norm_layer=norm_layer, **kwargs ) setattr(self, f'stage{i}', stage) in_chans = spec['DIM_EMBED'][i] dim_embed = spec['DIM_EMBED'][-1] self.norm = norm_layer(dim_embed) self.cls_token = spec['CLS_TOKEN'][-1] # Classifier head #self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity() #trunc_normal_(self.head.weight, std=0.02) self.head = nn.Identity() def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): if os.path.isfile(pretrained): pretrained_dict = torch.load(pretrained, map_location='cpu') logging.info(f'=> loading pretrained model {pretrained}') model_dict = self.state_dict() pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict.keys() } need_init_state_dict = {} for k, v in pretrained_dict.items(): need_init = ( k.split('.')[0] in pretrained_layers #or pretrained_layers[0] is '*' or pretrained_layers[0] == '*' ) if need_init: if verbose: logging.info(f'=> init {k} from {pretrained}') if 'pos_embed' in k and v.size() != model_dict[k].size(): size_pretrained = v.size() size_new = model_dict[k].size() logging.info( '=> load_pretrained: resized variant: {} to {}' .format(size_pretrained, size_new) ) ntok_new = size_new[1] ntok_new -= 1 posemb_tok, posemb_grid = v[:, :1], v[0, 1:] gs_old = int(np.sqrt(len(posemb_grid))) gs_new = int(np.sqrt(ntok_new)) logging.info( '=> load_pretrained: grid-size from {} to {}' .format(gs_old, gs_new) ) posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) zoom = (gs_new / gs_old, gs_new / gs_old, 1) posemb_grid = scipy.ndimage.zoom( posemb_grid, zoom, order=1 ) posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1) v = torch.tensor( np.concatenate([posemb_tok, posemb_grid], axis=1) ) need_init_state_dict[k] = v self.load_state_dict(need_init_state_dict, strict=False) @torch.jit.ignore def no_weight_decay(self): layers = set() for i in range(self.num_stages): layers.add(f'stage{i}.pos_embed') layers.add(f'stage{i}.cls_token') return layers def forward_features(self, x): for i in range(self.num_stages): x, cls_tokens = getattr(self, f'stage{i}')(x) if self.cls_token: x = self.norm(cls_tokens) #x = cls_tokens x = torch.squeeze(x) else: x = rearrange(x, 'b c h w -> b (h w) c') x = self.norm(x) x = torch.mean(x, dim=1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x @register_model def get_cls_model(**kwargs): msvit_spec = config.MODEL.SPEC msvit = ConvolutionalVisionTransformer( in_chans=1, #1 for spectrogram 3 for RGB num_classes=config.MODEL.NUM_CLASSES, act_layer=QuickGELU, norm_layer=partial(LayerNorm, eps=1e-5), init=getattr(msvit_spec, 'INIT', 'trunc_norm'), spec=msvit_spec ) # if config.MODEL.INIT_WEIGHTS: # msvit.init_weights( # config.MODEL.PRETRAINED, # config.MODEL.PRETRAINED_LAYERS, # config.VERBOSE # ) return msvit def build_model(config, **kwargs): model_name = config.MODEL.NAME if not is_model(model_name): raise ValueError(f'Unkown model: {model_name}') return model_entrypoints(model_name)(config, **kwargs) def cvt13(**kwargs): f = open('config.json', 'r') config = json.load(f) return ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC']) # only loades the config, no pretraining if __name__ == '__main__': f = open('config.json', 'r') config = json.load(f) model = ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC']) print(summary(model)) quit() print(summary(model, input_size=(4, 1, 128, 301)))