#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition

from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from timm.layers import trunc_normal_

from mmaudio.ext.synchformer import vit_helper


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage """

    def __init__(self, cfg):
        super().__init__()
        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
        self.patch_size = cfg.VIT.PATCH_SIZE
        self.in_chans = cfg.VIT.CHANNELS
        if cfg.TRAIN.DATASET == "Epickitchens":
            self.num_classes = [97, 300]
        else:
            self.num_classes = cfg.MODEL.NUM_CLASSES
        self.embed_dim = cfg.VIT.EMBED_DIM
        self.depth = cfg.VIT.DEPTH
        self.num_heads = cfg.VIT.NUM_HEADS
        self.mlp_ratio = cfg.VIT.MLP_RATIO
        self.qkv_bias = cfg.VIT.QKV_BIAS
        self.drop_rate = cfg.VIT.DROP
        self.drop_path_rate = cfg.VIT.DROP_PATH
        self.head_dropout = cfg.VIT.HEAD_DROPOUT
        self.video_input = cfg.VIT.VIDEO_INPUT
        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
        self.use_mlp = cfg.VIT.USE_MLP
        self.num_features = self.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
        self.head_act = cfg.VIT.HEAD_ACT
        self.cfg = cfg

        # Patch Embedding
        self.patch_embed = vit_helper.PatchEmbed(img_size=224,
                                                 patch_size=self.patch_size,
                                                 in_chans=self.in_chans,
                                                 embed_dim=self.embed_dim)

        # 3D Patch Embedding
        self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
                                                      temporal_resolution=self.temporal_resolution,
                                                      patch_size=self.patch_size,
                                                      in_chans=self.in_chans,
                                                      embed_dim=self.embed_dim,
                                                      z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
        self.patch_embed_3d.proj.weight.data = torch.zeros_like(
            self.patch_embed_3d.proj.weight.data)
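        # NOTE (assumption, not documented in this file): the 3D projection is
        # zero-initialised here, presumably so it can be filled in later, e.g.
        # from a pretrained 2D patch embedding or a loaded checkpoint.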
        # Number of patches
        if self.video_input:
            num_patches = self.patch_embed.num_patches * self.temporal_resolution
        else:
            num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # Positional embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
        trunc_normal_(self.pos_embed, std=.02)

        if self.cfg.VIT.POS_EMBED == "joint":
            self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
            trunc_normal_(self.st_embed, std=.02)
        elif self.cfg.VIT.POS_EMBED == "separate":
            self.temp_embed = nn.Parameter(
                torch.zeros(1, self.temporal_resolution, self.embed_dim))

        # Layer Blocks
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
        if self.cfg.VIT.ATTN_LAYER == "divided":
            self.blocks = nn.ModuleList([
                vit_helper.DividedSpaceTimeBlock(
                    attn_type=cfg.VIT.ATTN_LAYER,
                    dim=self.embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    drop=self.drop_rate,
                    attn_drop=self.attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                ) for i in range(self.depth)
            ])
        else:
            self.blocks = nn.ModuleList([
                vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
                                 dim=self.embed_dim,
                                 num_heads=self.num_heads,
                                 mlp_ratio=self.mlp_ratio,
                                 qkv_bias=self.qkv_bias,
                                 drop=self.drop_rate,
                                 attn_drop=self.attn_drop_rate,
                                 drop_path=dpr[i],
                                 norm_layer=norm_layer,
                                 use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
                for i in range(self.depth)
            ])
        self.norm = norm_layer(self.embed_dim)

        # MLP head
        if self.use_mlp:
            hidden_dim = self.embed_dim
            if self.head_act == 'tanh':
                # logging.info("Using TanH activation in MLP")
                act = nn.Tanh()
            elif self.head_act == 'gelu':
                # logging.info("Using GELU activation in MLP")
                act = nn.GELU()
            else:
                # logging.info("Using ReLU activation in MLP")
                act = nn.ReLU()
            self.pre_logits = nn.Sequential(
                OrderedDict([
                    ('fc', nn.Linear(self.embed_dim, hidden_dim)),
                    ('act', act),
                ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier Head
        self.head_drop = nn.Dropout(p=self.head_dropout)
        if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
            for i in range(len(self.num_classes)):
                setattr(self, "head%d" % i, nn.Linear(self.embed_dim, self.num_classes[i]))
        else:
            self.head = nn.Linear(self.embed_dim,
                                  self.num_classes) if self.num_classes > 0 else nn.Identity()

        # Initialize weights
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        if self.cfg.VIT.POS_EMBED == "joint":
            return {'pos_embed', 'cls_token', 'st_embed'}
        else:
            return {'pos_embed', 'cls_token', 'temp_embed'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())
    def forward_features(self, x):
        # if self.video_input:
        #     x = x[0]
        B = x.shape[0]

        # Tokenize input
        # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
        # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
        # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
        # apply patching on input
        x = self.patch_embed_3d(x)
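        # NOTE (assumption): x is expected here as a video tensor of shape
        # (B, C, T, H, W); PatchEmbed3D flattens it into a token sequence of
        # shape (B, num_patches, embed_dim).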
        tok_mask = None
        # else:
        #     tok_mask = None
        #     # 2D tokenization
        #     if self.video_input:
        #         x = x.permute(0, 2, 1, 3, 4)
        #         (B, T, C, H, W) = x.shape
        #         x = x.reshape(B * T, C, H, W)
        #     x = self.patch_embed(x)
        #     if self.video_input:
        #         (B2, T2, D2) = x.shape
        #         x = x.reshape(B, T * T2, D2)

        # Append CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # if tok_mask is not None:
        #     # prepend 1(=keep) to the mask to account for the CLS token as well
        #     tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
        # Interpolate positional embeddings
        # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
        #     pos_embed = self.pos_embed
        #     N = pos_embed.shape[1] - 1
        #     npatch = int((x.size(1) - 1) / self.temporal_resolution)
        #     class_emb = pos_embed[:, 0]
        #     pos_embed = pos_embed[:, 1:]
        #     dim = x.shape[-1]
        #     pos_embed = torch.nn.functional.interpolate(
        #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        #         scale_factor=math.sqrt(npatch / N),
        #         mode='bicubic',
        #     )
        #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        #     new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        # else:
        new_pos_embed = self.pos_embed
        npatch = self.patch_embed.num_patches

        # Add positional embeddings to input
        if self.video_input:
            if self.cfg.VIT.POS_EMBED == "separate":
                cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
                tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
                tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
                total_pos_embed = tile_pos_embed + tile_temporal_embed
                total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
                x = x + total_pos_embed
            elif self.cfg.VIT.POS_EMBED == "joint":
                x = x + self.st_embed
        else:
            # image input
            x = x + new_pos_embed

        # Apply positional dropout
        x = self.pos_drop(x)

        # Encoding using transformer layers
        for i, blk in enumerate(self.blocks):
            x = blk(x,
                    seq_len=npatch,
                    num_frames=self.temporal_resolution,
                    approx=self.cfg.VIT.APPROX_ATTN_TYPE,
                    num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
                    tok_mask=tok_mask)

        ### v-iashin: I moved it to the forward pass
        # x = self.norm(x)[:, 0]
        # x = self.pre_logits(x)
        ###
        return x, tok_mask
    # def forward(self, x):
    #     x = self.forward_features(x)
    #     ### v-iashin: here. This should leave the same forward output as before
    #     x = self.norm(x)[:, 0]
    #     x = self.pre_logits(x)
    #     ###
    #     x = self.head_drop(x)
    #     if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
    #         output = []
    #         for head in range(len(self.num_classes)):
    #             x_out = getattr(self, "head%d" % head)(x)
    #             if not self.training:
    #                 x_out = torch.nn.functional.softmax(x_out, dim=-1)
    #             output.append(x_out)
    #         return output
    #     else:
    #         x = self.head(x)
    #         if not self.training:
    #             x = torch.nn.functional.softmax(x, dim=-1)
    #         return x
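

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal construction-and-forward example. The config values below are
# illustrative assumptions in the style of Motionformer defaults, not the
# hyper-parameters of any released checkpoint; the (B, C, T, H, W) input layout
# and the T = TEMPORAL_RESOLUTION * PATCH_SIZE_TEMP relation are assumptions
# inferred from how PatchEmbed3D is used above.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        DATA=SimpleNamespace(TRAIN_CROP_SIZE=224),
        TRAIN=SimpleNamespace(DATASET="Kinetics"),  # anything other than "Epickitchens"
        MODEL=SimpleNamespace(NUM_CLASSES=400),
        VIT=SimpleNamespace(
            PATCH_SIZE=16, CHANNELS=3, EMBED_DIM=768, DEPTH=12, NUM_HEADS=12,
            MLP_RATIO=4, QKV_BIAS=True, DROP=0.0, DROP_PATH=0.1, HEAD_DROPOUT=0.0,
            VIDEO_INPUT=True, TEMPORAL_RESOLUTION=8, USE_MLP=True, ATTN_DROPOUT=0.0,
            HEAD_ACT='tanh', PATCH_SIZE_TEMP=2, POS_DROPOUT=0.0, POS_EMBED="separate",
            ATTN_LAYER="divided", USE_ORIGINAL_TRAJ_ATTN_CODE=True,
            APPROX_ATTN_TYPE="none", APPROX_ATTN_DIM=128,
        ),
    )

    model = VisionTransformer(cfg).eval()
    # Assumed input: one 16-frame 224x224 RGB clip,
    # i.e. T = TEMPORAL_RESOLUTION * PATCH_SIZE_TEMP = 8 * 2 = 16.
    clip = torch.randn(1, 3, 16, 224, 224)
    with torch.no_grad():
        tokens, tok_mask = model.forward_features(clip)
    # Expected under these assumptions: (1, 1 + 8 * 14 * 14, 768) = (1, 1569, 768)
    print(tokens.shape)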