import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
from torch import Tensor
from torchvision.transforms import functional as TVF
# ImageNet channel statistics; all supported backbones were pretrained with these.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Public checkpoint URLs for each supported pretrained backbone.
MODEL_URLS = {
    'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
    'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
    'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
}

# Every supported backbone uses the standard ImageNet normalization.
NORMALIZATION = {
    name: (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) for name in MODEL_URLS
}

# Architecture hyperparameters passed to timm's VisionTransformer constructor.
MODEL_KWARGS = {
    'vit_base_patch16_224_mae': dict(patch_size=16, embed_dim=768, depth=12, num_heads=12),
    'vit_small_patch16_224_msn': dict(patch_size=16, embed_dim=384, depth=12, num_heads=6),
    'vit_large_patch7_224_msn': dict(patch_size=7, embed_dim=1024, depth=24, num_heads=16),
}
class FeatureModel(ModelMixin, ConfigMixin):
    """Frozen pretrained ViT feature extractor (MAE / MSN backbones).

    Wraps a timm ``VisionTransformer`` initialized from a public MAE/MSN
    checkpoint and exposes (optionally upsampled) patch features and/or the
    CLS token. With ``model_name='identity'`` the module is a no-op that
    returns its input unchanged.
    """

    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        model_name: str = 'vit_small_patch16_224_mae',
        global_pool: str = '',  # '' or 'token'
    ) -> None:
        """Build the backbone and load its pretrained weights.

        Args:
            image_size: input resolution; position embeddings are resized to it.
            model_name: one of the keys of ``MODEL_KWARGS``, or 'identity'.
            global_pool: timm global-pool mode ('' keeps the token sequence).

        Raises:
            ValueError: if ``model_name`` is neither 'identity' nor a known model.
            NotImplementedError: if the downloaded checkpoint has an
                unrecognized layout.
        """
        super().__init__()
        self.model_name = model_name
        # Identity mode: no backbone is created and forward() returns its input.
        if self.model_name == 'identity':
            return
        # Fail fast with a clear message instead of a bare KeyError below.
        # NOTE(review): the default 'vit_small_patch16_224_mae' is NOT one of
        # the known names in MODEL_KWARGS, so the default config raises here —
        # confirm the intended default (e.g. 'vit_small_patch16_224_msn').
        if model_name not in MODEL_KWARGS:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected 'identity' or one of "
                f"{sorted(MODEL_KWARGS)}")
        # Create the backbone (num_classes=0 removes the classification head).
        self.model = VisionTransformer(
            img_size=image_size, num_classes=0, global_pool=global_pool,
            **MODEL_KWARGS[model_name])
        # Model properties
        self.feature_dim = self.model.embed_dim
        self.mean, self.std = NORMALIZATION[model_name]
        # Load the pretrained checkpoint. MAE checkpoints store weights under
        # 'model'; MSN checkpoints under 'target_encoder' with DDP 'module.'
        # prefixes and a projection head we do not use.
        checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'target_encoder' in checkpoint:
            state_dict = checkpoint['target_encoder']
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # Drop the MSN projection head ('fc.*'); only backbone features are
            # kept. For details on the projection head, see
            # https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
        else:
            raise NotImplementedError(
                f'Unrecognized checkpoint layout for {model_name!r}: '
                f'top-level keys {list(checkpoint)}')
        # Interpolate position embeddings when image_size differs from pretraining.
        state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Projection head is intentionally disabled; kept as an attribute so
        # forward(..., return_projection_head_output=True) is a harmless no-op.
        self.fc = nn.Identity()

    def denormalize(self, img: Tensor):
        """Invert ImageNet normalization and clip the result to [0, 1]."""
        # normalize(x, mean=-m/s, std=1/s) computes (x + m/s) * s = x*s + m,
        # the exact inverse of (x - m) / s.
        img = TVF.normalize(img, mean=[-m / s for m, s in zip(self.mean, self.std)],
                            std=[1 / s for s in self.std])
        return torch.clip(img, 0, 1)

    def normalize(self, img: Tensor):
        """Apply the ImageNet mean/std normalization expected by the backbone."""
        return TVF.normalize(img, mean=self.mean, std=self.std)

    def forward(
        self,
        x: Tensor,
        return_type: str = 'features',
        return_upscaled_features: bool = True,
        return_projection_head_output: bool = False,
    ):
        """Normalizes the input `x` and runs it through `model` to obtain features.

        Args:
            x: image batch of shape (B, C, H, W); assumed unnormalized
                (this method applies ``normalize`` itself).
            return_type: 'cls_token', 'features', or 'all'.
            return_upscaled_features: bilinearly upsample patch features back
                to the input resolution (H, W).
            return_projection_head_output: pass the CLS token through
                ``self.fc`` (an identity unless a projection head is attached).

        Returns:
            The CLS token (B, D), the feature map (B, D, H', W'), or a
            ``(cls, features)`` tuple, depending on ``return_type``.
        """
        assert return_type in {'cls_token', 'features', 'all'}
        # Identity mode short-circuits everything.
        if self.model_name == 'identity':
            return x
        # Normalize and forward
        B, C, H, W = x.shape
        x = self.normalize(x)
        feats = self.model(x)
        # Reshape the patch tokens into an image-like (B, D, H', W') map.
        if return_type in {'features', 'all'}:
            B, T, D = feats.shape
            # Requires a square patch grid (T - 1 patch tokens plus CLS).
            assert math.sqrt(T - 1).is_integer()
            HW_down = int(math.sqrt(T - 1))  # subtract one for CLS token
            output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2)  # (B, D, H_down, W_down)
            if return_upscaled_features:
                output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
                                             align_corners=False)  # (B, D, H_orig, W_orig)
        # CLS token, optionally through the (identity) projection head.
        output_cls = feats[:, 0]
        if return_projection_head_output and return_type in {'cls_token', 'all'}:
            output_cls = self.fc(output_cls)
        # Return
        if return_type == 'cls_token':
            return output_cls
        elif return_type == 'features':
            return output_feats
        else:
            return output_cls, output_feats