"""Pretrained ViT feature extractors (MAE / MSN checkpoints) wrapped as a diffusers model."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
from torch import Tensor
from torchvision.transforms import functional as TVF
# Standard ImageNet per-channel normalization statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Public pretrained checkpoint URLs for each supported backbone.
MODEL_URLS = {
    'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
    'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
    'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
}

# (mean, std) used to normalize inputs for each backbone.
NORMALIZATION = {
    'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
}

# Architecture hyperparameters passed to timm's VisionTransformer.
MODEL_KWARGS = {
    'vit_base_patch16_224_mae': dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
    ),
    'vit_small_patch16_224_msn': dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
    ),
    'vit_large_patch7_224_msn': dict(
        patch_size=7, embed_dim=1024, depth=24, num_heads=16,
    ),
}
class FeatureModel(ModelMixin, ConfigMixin):
    """Frozen pretrained ViT feature extractor (MAE / MSN checkpoints).

    Wraps a timm ``VisionTransformer``, downloads and loads the matching
    public pretrained checkpoint, and exposes spatial patch features and/or
    the CLS token. Pass ``model_name='identity'`` for a pass-through module
    whose ``forward`` returns its input unchanged.
    """

    def __init__(
        self,
        image_size: int = 224,
        # BUGFIX: the previous default 'vit_small_patch16_224_mae' is not a
        # key of MODEL_URLS/MODEL_KWARGS/NORMALIZATION and raised KeyError;
        # the valid small model is the MSN variant.
        model_name: str = 'vit_small_patch16_224_msn',
        global_pool: str = '',  # '' or 'token'
    ) -> None:
        """Build the backbone and load its pretrained weights.

        Args:
            image_size: Input resolution; the checkpoint's positional
                embedding is resized to match it.
            model_name: A key of ``MODEL_URLS``, or ``'identity'`` for a
                no-op model.
            global_pool: Forwarded to timm's ``VisionTransformer``
                ('' or 'token').
        """
        super().__init__()
        self.model_name = model_name

        # Identity: no backbone is constructed; forward() is a pass-through.
        if self.model_name == 'identity':
            return

        # Create the backbone (num_classes=0 removes the classifier head).
        self.model = VisionTransformer(
            img_size=image_size, num_classes=0, global_pool=global_pool,
            **MODEL_KWARGS[model_name])

        # Model properties used by normalize()/denormalize() and callers.
        self.feature_dim = self.model.embed_dim
        self.mean, self.std = NORMALIZATION[model_name]

        # Load the pretrained checkpoint. MAE checkpoints store weights under
        # 'model'; MSN checkpoints under 'target_encoder' with DataParallel
        # 'module.' prefixes and an extra projection head.
        checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'target_encoder' in checkpoint:
            state_dict = checkpoint['target_encoder']
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # Drop the MSN projection head ('fc.*'); only backbone features
            # are used here. For the head's structure see
            # https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
        else:
            raise NotImplementedError()

        # Resize the positional embedding when image_size differs from the
        # checkpoint's training resolution.
        state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # The MSN projection head is intentionally disabled; keep an identity
        # so forward(return_projection_head_output=True) stays well-defined.
        self.fc = nn.Identity()

    def denormalize(self, img: Tensor) -> Tensor:
        """Invert :meth:`normalize` and clip the result to [0, 1]."""
        inv_mean = [-m / s for m, s in zip(self.mean, self.std)]
        inv_std = [1 / s for s in self.std]
        img = TVF.normalize(img, mean=inv_mean, std=inv_std)
        return torch.clip(img, 0, 1)

    def normalize(self, img: Tensor) -> Tensor:
        """Apply the backbone's per-channel mean/std input normalization."""
        return TVF.normalize(img, mean=self.mean, std=self.std)

    def forward(
        self,
        x: Tensor,
        return_type: str = 'features',
        return_upscaled_features: bool = True,
        return_projection_head_output: bool = False,
    ):
        """Normalize the input `x` and run it through `model` to get features.

        Args:
            x: Image batch of shape (B, C, H, W); presumably values in [0, 1]
                since they are mean/std-normalized here — confirm with caller.
            return_type: 'features' (spatial map), 'cls_token', or 'all'
                (a ``(cls, features)`` tuple).
            return_upscaled_features: If True, bilinearly upsample the spatial
                features back to the input resolution (H, W).
            return_projection_head_output: If True (and a token is returned),
                pass the CLS token through ``self.fc`` (currently identity).

        Returns:
            A (B, D, H', W') feature map, a (B, D) CLS token, or both.
        """
        assert return_type in {'cls_token', 'features', 'all'}

        # Identity model: pass-through.
        if self.model_name == 'identity':
            return x

        # Normalize and forward through the backbone.
        B, C, H, W = x.shape
        x = self.normalize(x)
        feats = self.model(x)  # (B, 1 + T, D), CLS token first

        # Reshape patch tokens to an image-like (B, D, H_down, W_down) map.
        if return_type in {'features', 'all'}:
            B, T, D = feats.shape
            # Patch tokens (all but CLS) must tile a square grid.
            assert math.sqrt(T - 1).is_integer()
            HW_down = int(math.sqrt(T - 1))  # subtract one for CLS token
            output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2)
            if return_upscaled_features:
                output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
                                             align_corners=False)  # (B, D, H_orig, W_orig)

        # Optional projection head (identity unless re-enabled) for the token.
        output_cls = feats[:, 0]
        if return_projection_head_output and return_type in {'cls_token', 'all'}:
            output_cls = self.fc(output_cls)

        # Return according to the requested type.
        if return_type == 'cls_token':
            return output_cls
        elif return_type == 'features':
            return output_feats
        else:
            return output_cls, output_feats