"""Feature extraction models: timm Vision Transformers pretrained with MAE or MSN,
wrapped for use as image feature extractors."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
from torch import Tensor
from torchvision.transforms import functional as TVF


IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

MODEL_URLS = {
    'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
    'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
    'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
}

NORMALIZATION = {
    'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
}

MODEL_KWARGS = {
    'vit_base_patch16_224_mae': dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
    ),
    'vit_small_patch16_224_msn': dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
    ),
    'vit_large_patch7_224_msn': dict(
        patch_size=7, embed_dim=1024, depth=24, num_heads=16,
    ),
}


class FeatureModel(ModelMixin, ConfigMixin):
    """Wraps a pretrained ViT (MAE or MSN) for extracting image features.

    With `model_name='identity'`, the model is a no-op and `forward` returns its input.
    """

    @register_to_config
    def __init__(
        self, 
        image_size: int = 224,
        model_name: str = 'vit_small_patch16_224_msn',  # a key of MODEL_URLS, or 'identity'
        global_pool: str = '',  # '' or 'token'
    ) -> None:
        super().__init__()
        self.model_name = model_name

        # Identity
        if self.model_name == 'identity':
            return

        # Create model
        self.model = VisionTransformer(
            img_size=image_size, num_classes=0, global_pool=global_pool,
            **MODEL_KWARGS[model_name])

        # Model properties
        self.feature_dim = self.model.embed_dim
        self.mean, self.std = NORMALIZATION[model_name]

        # # Modify MSN model with output head from training
        # if model_name.endswith('msn'):
        #     use_bn = True
        #     emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else 
        #         768 if 'base' in model_name else 1024 if 'large' in model_name else 1280)
        #     hidden_dim = 2048
        #     output_dim = 256
        #     self.model.fc = None
        #     fc = OrderedDict([])
        #     fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim)
        #     if use_bn:
        #         fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim)
        #     fc['gelu1'] = torch.nn.GELU()
        #     fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim)
        #     if use_bn:
        #         fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim)
        #     fc['gelu2'] = torch.nn.GELU()
        #     fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
        #     self.model.fc = torch.nn.Sequential(fc)
        
        # Load pretrained checkpoint
        checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'target_encoder' in checkpoint:
            state_dict = checkpoint['target_encoder']
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # NOTE: The line below drops the projection head weights; keep it when the
            # projection head is disabled (as it is here), and comment it out when using the head.
            # See https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
            # for more info about the projection head
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
        else:
            raise NotImplementedError(f'Unrecognized checkpoint format with keys: {list(checkpoint.keys())}')
        state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # # Modify MSN model with output head from training
        # if model_name.endswith('msn'):
        #     self.fc = self.model.fc
        #     del self.model.fc
        # else:
        #     self.fc = nn.Identity()
        
        # NOTE: I've disabled the whole projection head stuff for simplicity for now
        self.fc = nn.Identity()

    def denormalize(self, img: Tensor):
        # Invert the per-channel normalization: x = x_norm * std + mean, written as a
        # single Normalize with mean' = -mean/std and std' = 1/std
        img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std])
        return torch.clip(img, 0, 1)

    def normalize(self, img: Tensor):
        return TVF.normalize(img, mean=self.mean, std=self.std)

    def forward(
        self, 
        x: Tensor, 
        return_type: str = 'features',
        return_upscaled_features: bool = True,
        return_projection_head_output: bool = False,
    ):
        """Normalizes the input `x` and runs it through `model` to obtain features"""
        assert return_type in {'cls_token', 'features', 'all'}

        # Identity
        if self.model_name == 'identity':
            return x
        
        # Normalize and forward
        B, C, H, W = x.shape
        x = self.normalize(x)
        feats = self.model(x)

        # Reshape to image-like size
        if return_type in {'features', 'all'}:
            B, T, D = feats.shape
            assert math.sqrt(T - 1).is_integer(), f'Expected a square grid of patch tokens, got {T - 1}'
            HW_down = int(math.sqrt(T - 1))  # subtract one for the CLS token
            output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2)  # (B, D, H_down, W_down)
            if return_upscaled_features:
                output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
                    align_corners=False)  # (B, D, H_orig, W_orig)

        # Head for MSN
        output_cls = feats[:, 0]
        if return_projection_head_output and return_type in {'cls_token', 'all'}:
            output_cls = self.fc(output_cls)
        
        # Return
        if return_type == 'cls_token':
            return output_cls
        elif return_type == 'features':
            return output_feats
        else:
            return output_cls, output_feats
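

# A minimal usage sketch: build the MSN ViT-S/16 feature extractor and pull both the
# CLS token and the dense (upscaled) patch features for a dummy batch. This assumes
# network access to download the pretrained checkpoint; the shapes in the comments
# follow from embed_dim=384 and inputs in [0, 1] RGB at 224x224.
if __name__ == '__main__':
    model = FeatureModel(image_size=224, model_name='vit_small_patch16_224_msn')
    images = torch.rand(2, 3, 224, 224)  # dummy batch of RGB images in [0, 1]
    with torch.no_grad():
        cls_token, features = model(images, return_type='all')
    print(cls_token.shape)  # torch.Size([2, 384])
    print(features.shape)   # torch.Size([2, 384, 224, 224])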