File size: 3,201 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict


def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW."""
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(conv_weights)


def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    import re

    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        for old_key in layer_keys:
            pattern = ".*block{}.*conv2.*".format(ix)
            r = re.match(pattern, old_key)
            if r is None:
                continue
            for param in ["weight", "bias"]:
                if old_key.find(param) is -1:
                    continue
                if "unit01" in old_key:
                    continue
                new_key = old_key.replace("conv2.{}".format(param), "conv2.conv.{}".format(param))
                print("pattern: {}, old_key: {}, new_key: {}".format(pattern, old_key, new_key))
                # Calculate SD conv weight
                w = state_dict[old_key]
                v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
                w = (w - m) / torch.sqrt(v + 1e-10)

                state_dict[new_key] = w
                del state_dict[old_key]
    return state_dict


def load_big_format(cfg, f):
    model = OrderedDict()
    weights = np.load(f)

    cmap = {"a": 1, "b": 2, "c": 3}
    for key, val in weights.items():
        old_key = key.replace("resnet/", "")
        if "root_block" in old_key:
            new_key = "root.conv.weight"
        elif "/proj/standardized_conv2d/kernel" in old_key:
            key_pattern = old_key.replace("/proj/standardized_conv2d/kernel", "").replace("resnet/", "")
            bname, uname, cidx = key_pattern.split("/")
            new_key = "{}.downsample.{}.conv{}.weight".format(bname, uname, cmap[cidx])
        elif "/standardized_conv2d/kernel" in old_key:
            key_pattern = old_key.replace("/standardized_conv2d/kernel", "").replace("resnet/", "")
            bname, uname, cidx = key_pattern.split("/")
            new_key = "{}.{}.conv{}.weight".format(bname, uname, cmap[cidx])
        elif "/group_norm/gamma" in old_key:
            key_pattern = old_key.replace("/group_norm/gamma", "").replace("resnet/", "")
            bname, uname, cidx = key_pattern.split("/")
            new_key = "{}.{}.gn{}.weight".format(bname, uname, cmap[cidx])
        elif "/group_norm/beta" in old_key:
            key_pattern = old_key.replace("/group_norm/beta", "").replace("resnet/", "")
            bname, uname, cidx = key_pattern.split("/")
            new_key = "{}.{}.gn{}.bias".format(bname, uname, cmap[cidx])
        else:
            print("Unknown key {}".format(old_key))
            continue
        print("Map {} -> {}".format(key, new_key))
        model[new_key] = tf2th(val)

    model = _rename_conv_weights_for_deformable_conv_layers(model, cfg)

    return dict(model=model)