File size: 1,622 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict


def _remove_bn_statics(state_dict):
    """Delete running-statistics entries from ``state_dict`` in place.

    Any key containing ``"running_mean"``, ``"running_var"`` or
    ``"num_batches_tracked"`` is removed. All other entries are left as-is.

    Args:
        state_dict: Mutable mapping keyed by parameter name.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    stat_markers = ("running_mean", "running_var", "num_batches_tracked")
    # Collect the doomed keys first so the mapping is never mutated while
    # it is being scanned.
    doomed = [
        name
        for name in sorted(state_dict.keys())
        if any(marker in name for marker in stat_markers)
    ]
    for name in doomed:
        del state_dict[name]
    return state_dict


def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    """Rename ``conv2.<param>`` keys to ``conv2.conv.<param>`` for DCN stages.

    ``cfg.MODEL.RESNETS.STAGE_WITH_DCN`` is enumerated starting at 1. For
    every stage index ``ix`` whose flag is truthy, each key matching
    ``.*layer{ix}.*conv2.*`` that contains ``"weight"`` or ``"bias"`` has the
    substring ``conv2.<param>`` replaced by ``conv2.conv.<param>``. Keys
    containing ``"unit01"`` are never renamed.

    Keys that match the stage pattern but do not contain the literal
    ``conv2.<param>`` substring (for example keys that already read
    ``conv2.conv.weight``) are kept unchanged instead of being dropped.

    Args:
        state_dict: Mutable mapping keyed by parameter name; modified in
            place.
        cfg: Object exposing ``MODEL.RESNETS.STAGE_WITH_DCN`` as an iterable
            of per-stage flags.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    import re

    # Snapshot of the keys so entries can be renamed while looping.
    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        # The pattern only depends on the stage, so build it once per stage.
        pattern = ".*layer{}.*conv2.*".format(ix)
        matcher = re.compile(pattern)
        for old_key in layer_keys:
            if matcher.match(old_key) is None:
                continue
            if "unit01" in old_key:
                continue
            if old_key not in state_dict:
                # Already renamed while processing an earlier stage.
                continue
            for param in ("weight", "bias"):
                if param not in old_key:
                    continue
                new_key = old_key.replace("conv2.{}".format(param), "conv2.conv.{}".format(param))
                if new_key == old_key:
                    # Nothing to rename; assigning then deleting would
                    # silently drop the entry.
                    continue
                print("pattern: {}, old_key: {}, new_key: {}".format(pattern, old_key, new_key))
                state_dict[new_key] = state_dict.pop(old_key)
                # The old key no longer exists, so stop looking at it.
                break
    return state_dict


def load_pretrain_format(cfg, f):
    """Load a checkpoint with ``torch.load`` and adapt its keys.

    The loaded object is passed through ``_remove_bn_statics`` and then
    ``_rename_conv_weights_for_deformable_conv_layers`` before being wrapped
    in a dictionary.

    Args:
        cfg: Config object forwarded to
            ``_rename_conv_weights_for_deformable_conv_layers``.
        f: Checkpoint source handed directly to ``torch.load``.

    Returns:
        A dict with a single ``"model"`` entry holding the adapted mapping.
    """
    # NOTE(review): assumes the checkpoint is a flat mapping of parameter
    # names to values -- confirm against the files this is used with.
    loaded = torch.load(f)
    adapted = _rename_conv_weights_for_deformable_conv_layers(
        _remove_bn_statics(loaded), cfg
    )
    return {"model": adapted}