import os
from pathlib import Path

import yaml
from timm.models.helpers import load_pretrained, load_custom_pretrained
from timm.models.registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn

import segmenter_model.torch as ptu
import torch
# DecoderLinear and DeepLabHead are used below; they are assumed to live in segmenter_model.decoder.
from segmenter_model.decoder import DecoderLinear, DeepLabHead, MaskTransformer
from segmenter_model.segmenter import Segmenter
from segmenter_model.vit_dino import vit_small, VisionTransformer


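# timm does not ship a ViT-B/8 @ 384 configuration, so one is registered locally for use by create_vit below.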
@register_model
def vit_base_patch8_384(pretrained=False, **kwargs):
    """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        "vit_base_patch8_384",
        pretrained=pretrained,
        default_cfg=dict(
            url="",
            input_size=(3, 384, 384),
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            num_classes=1000,
        ),
        **model_kwargs,
    )
    return model


def create_vit(model_cfg):
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    # Pop the optional checkpoint path so it is not forwarded to the backbone constructor.
    pretrained_weights = model_cfg.pop('pretrained_weights', None)

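    # Two backbone families are supported: DINO ViTs (loaded from a local checkpoint or torch.hub)
    # and timm-style ViTs built with the local VisionTransformer. The `pretrained_weights` value
    # popped above is not used further; the dino_vits16 branch hard-codes its own checkpoint.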
    if 'dino' in backbone:
        if backbone.lower() == 'dino_vits16':
            model_cfg['drop_rate'] = model_cfg['dropout']
            model = vit_small(**model_cfg)
            # The DINO ViT-S/16 checkpoint path is hard-coded; download it if it is not present locally.
            pretrained_weights = 'dino_deitsmall16_pretrain.pth'
            if not os.path.exists(pretrained_weights):
                import urllib.request
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
                    pretrained_weights)
            model.load_state_dict(torch.load(pretrained_weights), strict=True)
        else:
            model = torch.hub.load('facebookresearch/dino:main', backbone)
        setattr(model, 'd_model', model.num_features)
        setattr(model, 'patch_size', model.patch_embed.patch_size)
        setattr(model, 'distilled', False)
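        # Replace forward() so the encoder returns the (normalized) token sequence of the last block;
        # `return_features` is accepted only for interface compatibility with the Segmenter encoder.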
        model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0]
    else:
        normalization = model_cfg.pop("normalization")
        model_cfg["n_cls"] = 1000
        mlp_expansion_ratio = 4
        model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

        if backbone in default_cfgs:
            default_cfg = default_cfgs[backbone]
        else:
            default_cfg = dict(
                pretrained=False,
                num_classes=1000,
                drop_rate=0.0,
                drop_path_rate=0.0,
                drop_block_rate=None,
            )

        default_cfg["input_size"] = (
            3,
            model_cfg["image_size"][0],
            model_cfg["image_size"][1],
        )
        model = VisionTransformer(**model_cfg)
        if backbone == "vit_base_patch8_384":
            path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth")
            state_dict = torch.load(path, map_location="cpu")
            filtered_dict = checkpoint_filter_fn(state_dict, model)
            model.load_state_dict(filtered_dict, strict=True)
        elif "deit" in backbone:
            load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
        else:
            load_custom_pretrained(model, default_cfg)

    return model


def create_decoder(encoder, decoder_cfg):
    decoder_cfg = decoder_cfg.copy()
    name = decoder_cfg.pop("name")
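    # Every decoder needs the encoder embedding width and patch size to map tokens back to pixels.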
    decoder_cfg["d_encoder"] = encoder.d_model
    decoder_cfg["patch_size"] = encoder.patch_size

    if "linear" in name:
        decoder = DecoderLinear(**decoder_cfg)
    elif name == "mask_transformer":
        dim = encoder.d_model
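        # Size the decoder to the encoder: one attention head per 64 channels and a 4x MLP expansion.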
        n_heads = dim // 64
        decoder_cfg["n_heads"] = n_heads
        decoder_cfg["d_model"] = dim
        decoder_cfg["d_ff"] = 4 * dim
        decoder = MaskTransformer(**decoder_cfg)
    elif 'deeplab' in name:
        decoder = DeepLabHead(in_channels=encoder.d_model, num_classes=decoder_cfg["n_cls"],
                              patch_size=decoder_cfg["patch_size"])
    else:
        raise ValueError(f"Unknown decoder: {name}")
    return decoder


def create_segmenter(model_cfg):
    model_cfg = model_cfg.copy()
    decoder_cfg = model_cfg.pop("decoder")
    decoder_cfg["n_cls"] = model_cfg["n_cls"]

    if 'weights_path' in model_cfg:
        weights_path = model_cfg.pop('weights_path')
    else:
        weights_path = None

    encoder = create_vit(model_cfg)
    decoder = create_decoder(encoder, decoder_cfg)
    model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])

    if weights_path is not None:
        # NOTE: the loading code below is unreachable while this guard is in place.
        raise Exception('Attempting to load weights into the complete segmenter inside the create_segmenter method!')
        state_dict = torch.load(weights_path, map_location="cpu")
        if 'model' in state_dict:
            state_dict = state_dict['model']
        msg = model.load_state_dict(state_dict, strict=False)
        print(msg)

    return model


def load_model(model_path, decoder_only=False, variant_path=None):
    variant_path = Path(model_path).parent / "variant.yml" if variant_path is None else variant_path
    with open(variant_path, "r") as f:
        variant = yaml.load(f, Loader=yaml.FullLoader)
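    # net_kwargs mirrors the arguments passed to create_segmenter at training time.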
    net_kwargs = variant["net_kwargs"]

    model = create_segmenter(net_kwargs)
    data = torch.load(model_path, map_location=ptu.device)
    checkpoint = data["model"]

    if decoder_only:
        model.decoder.load_state_dict(checkpoint, strict=True)
    else:
        model.load_state_dict(checkpoint, strict=True)

    return model, variant
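

# Minimal usage sketch (illustrative, not part of the original module). The checkpoint path,
# the command-line handling and the dummy forward pass below are assumptions about how these
# factory functions are typically used.
if __name__ == "__main__":
    import sys

    # Expected call: python <this file> /path/to/checkpoint.pth  (with variant.yml next to it)
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoint.pth"
    model, variant = load_model(ckpt_path)
    model.eval()

    # Dry-run a dummy batch at the training resolution recorded in variant.yml.
    im_size = variant["net_kwargs"].get("image_size", 384)
    h, w = (im_size, im_size) if isinstance(im_size, int) else tuple(im_size)
    dummy = torch.zeros(1, 3, h, w)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: (1, n_cls, h, w) per-class logits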