HDM-interaction-recon / model /feature_model.py
xiexh20's picture
add hdm demo v1
2fd6166
raw
history blame
6.18 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
from torch import Tensor
from torchvision.transforms import functional as TVF
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
MODEL_URLS = {
'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
}
NORMALIZATION = {
'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
}
MODEL_KWARGS = {
'vit_base_patch16_224_mae': dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
),
'vit_small_patch16_224_msn': dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6,
),
'vit_large_patch7_224_msn': dict(
patch_size=7, embed_dim=1024, depth=24, num_heads=16,
)
}
class FeatureModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
image_size: int = 224,
model_name: str = 'vit_small_patch16_224_mae',
global_pool: str = '', # '' or 'token'
) -> None:
super().__init__()
self.model_name = model_name
# Identity
if self.model_name == 'identity':
return
# Create model
self.model = VisionTransformer(
img_size=image_size, num_classes=0, global_pool=global_pool,
**MODEL_KWARGS[model_name])
# Model properties
self.feature_dim = self.model.embed_dim
self.mean, self.std = NORMALIZATION[model_name]
# # Modify MSN model with output head from training
# if model_name.endswith('msn'):
# use_bn = True
# emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else
# 768 if 'base' in model_name else 1024 if 'large' in model_name else 1280)
# hidden_dim = 2048
# output_dim = 256
# self.model.fc = None
# fc = OrderedDict([])
# fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim)
# if use_bn:
# fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim)
# fc['gelu1'] = torch.nn.GELU()
# fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim)
# if use_bn:
# fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim)
# fc['gelu2'] = torch.nn.GELU()
# fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
# self.model.fc = torch.nn.Sequential(fc)
# Load pretrained checkpoint
checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'target_encoder' in checkpoint:
state_dict = checkpoint['target_encoder']
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
# NOTE: Comment the line below if using the projection head, uncomment if not using it
# See https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
# for more info about the projection head
state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
else:
raise NotImplementedError()
state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
self.model.load_state_dict(state_dict)
self.model.eval()
# # Modify MSN model with output head from training
# if model_name.endswith('msn'):
# self.fc = self.model.fc
# del self.model.fc
# else:
# self.fc = nn.Identity()
# NOTE: I've disabled the whole projection head stuff for simplicity for now
self.fc = nn.Identity()
def denormalize(self, img: Tensor):
img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std])
return torch.clip(img, 0, 1)
def normalize(self, img: Tensor):
return TVF.normalize(img, mean=self.mean, std=self.std)
def forward(
self,
x: Tensor,
return_type: str = 'features',
return_upscaled_features: bool = True,
return_projection_head_output: bool = False,
):
"""Normalizes the input `x` and runs it through `model` to obtain features"""
assert return_type in {'cls_token', 'features', 'all'}
# Identity
if self.model_name == 'identity':
return x
# Normalize and forward
B, C, H, W = x.shape
x = self.normalize(x)
feats = self.model(x)
# Reshape to image-like size
if return_type in {'features', 'all'}:
B, T, D = feats.shape
assert math.sqrt(T - 1).is_integer()
HW_down = int(math.sqrt(T - 1)) # subtract one for CLS token
output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2) # (B, D, H_down, W_down)
if return_upscaled_features:
output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
align_corners=False) # (B, D, H_orig, W_orig)
# Head for MSN
output_cls = feats[:, 0]
if return_projection_head_output and return_type in {'cls_token', 'all'}:
output_cls = self.fc(output_cls)
# Return
if return_type == 'cls_token':
return output_cls
elif return_type == 'features':
return output_feats
else:
return output_cls, output_feats