File size: 5,567 Bytes
186701e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
from collections import OrderedDict

import torch


# Ordered renaming rules from YOLOv6 module names to MMYOLO module names.
# Each entry is (YOLOv6 name, MMYOLO name, rename_inner). When
# ``rename_inner`` is True the YOLOv6 sub-module names are also mapped:
# ``.cv`` -> ``.conv`` and ``.m.`` -> ``.block.``.
# Only the first rule whose YOLOv6 name occurs in a key is applied, so the
# order of this tuple is significant.
_YOLOV6_TO_MMYOLO_RULES = (
    ('ERBlock_2', 'stage1.0', True),
    ('ERBlock_3', 'stage2.0', True),
    ('ERBlock_4', 'stage3.0', True),
    ('ERBlock_5', 'stage4.0', True),
    ('reduce_layer0', 'reduce_layers.2', False),
    ('Rep_p4', 'top_down_layers.0.0', True),
    ('reduce_layer1', 'top_down_layers.0.1', True),
    ('Rep_p3', 'top_down_layers.1', True),
    ('Bifusion0', 'upsample_layers.0', True),
    ('Bifusion1', 'upsample_layers.1', True),
    ('Rep_n3', 'bottom_up_layers.0', True),
    ('Rep_n4', 'bottom_up_layers.1', True),
    ('downsample2', 'downsample_layers.0', False),
    ('downsample1', 'downsample_layers.1', False),
)


def _rename_key(key):
    """Return the MMYOLO name for one YOLOv6 state-dict key.

    Args:
        key (str): A key from the YOLOv6 model's ``state_dict()``.

    Returns:
        str | None: The renamed key, or ``None`` when the entry is dropped
        (``detect`` keys containing ``proj``, and any ``anchors`` /
        ``anchor_grid`` key).
    """
    name = key
    if 'detect' in key:
        if 'proj' in key:
            return None
        name = key.replace('detect', 'bbox_head.head_module')
    if 'anchors' in key or 'anchor_grid' in key:
        return None

    for old, new, rename_inner in _YOLOV6_TO_MMYOLO_RULES:
        if old not in key:
            continue
        # Rules rewrite the ORIGINAL key (as the converter always has),
        # so the inner-name checks below also test the original key.
        name = key.replace(old, new)
        if rename_inner:
            if '.cv' in key:
                name = name.replace('.cv', '.conv')
            if '.m.' in key:
                name = name.replace('.m.', '.block.')
        if old == 'ERBlock_5' and 'stage4.0.2' in name:
            # Child 2 of ERBlock_5 is mapped to a separate MMYOLO stage.
            name = name.replace('stage4.0.2', 'stage4.1')
            name = name.replace('cv', 'conv')
        elif (old in ('Bifusion0', 'Bifusion1')
              and '.upsample_transpose.' in key):
            # Collapse the extra 'upsample_transpose' nesting level.
            name = name.replace('.upsample_transpose.', '.')
        break
    return name


def _split_ns_heads(state_dict):
    """Rename the two regression heads of a yolov6_v3_n/s state dict.

    The yolov6_v3_n/s has two regression heads. One called
    'reg_preds_lrtb' is a regular anchor-free head, which is used for
    inference; it is renamed to 'reg_preds'. One called 'reg_preds' is a
    DFL style head, which is only used in training; it is renamed to
    'distill_ns_head'.

    Args:
        state_dict (OrderedDict): Already-renamed MMYOLO state dict.

    Returns:
        OrderedDict: A new dict with the head keys renamed, in the same
        order.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if 'reg_preds_lrtb' in key:
            key = key.replace('reg_preds_lrtb', 'reg_preds')
        elif 'reg_preds' in key:
            key = key.replace('reg_preds', 'distill_ns_head')
        renamed[key] = value
    return renamed


def convert(src, dst):
    """Convert an official YOLOv6 v3 checkpoint to MMYOLO key names.

    The model stored under ``'ema'`` is used when that entry is truthy,
    otherwise the one under ``'model'``.

    Args:
        src (str): Path of the YOLOv6 checkpoint. It contains a pickled
            model object, so the YOLOv6 python modules must be importable.
        dst (str): Output path; receives ``{'state_dict': OrderedDict}``.

    Raises:
        RuntimeError: If a module needed to unpickle the checkpoint is
            missing (the script is not run from inside the YOLOv6 repo).
    """
    import inspect
    import sys
    sys.path.append('yolov6')

    load_kwargs = {'map_location': torch.device('cpu')}
    # The checkpoint pickles the whole model object. Since PyTorch 2.6,
    # ``torch.load`` defaults to ``weights_only=True`` and refuses to
    # unpickle such objects, so opt out explicitly wherever the argument
    # exists (older PyTorch releases do not accept it).
    # WARNING: unpickling can execute arbitrary code -- only convert
    # checkpoints obtained from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    try:
        ckpt = torch.load(src, **load_kwargs)
    except ModuleNotFoundError as err:
        raise RuntimeError(
            'This script must be placed under the meituan/YOLOv6 repo,'
            ' because loading the official pretrained model need'
            ' some python files to build model.') from err
    # The saved model is the model before reparameterization
    model = ckpt['ema' if ckpt.get('ema') else 'model'].float()

    new_state_dict = OrderedDict()
    is_ns = False
    for key, value in model.state_dict().items():
        # A kept 'reg_preds_lrtb' head marks the two-head n/s variant.
        if ('detect' in key and 'proj' not in key
                and 'reg_preds_lrtb' in key):
            is_ns = True
        name = _rename_key(key)
        if name is not None:
            new_state_dict[name] = value

    if is_ns:
        new_state_dict = _split_ns_heads(new_state_dict)

    torch.save({'state_dict': new_state_dict}, dst)


# Note: This script must be placed under the yolov6 repo to run.
def main():
    """Command-line entry point: read ``--src`` / ``--dst`` and convert.

    ``--src`` defaults to ``yolov6s.pt`` and ``--dst`` to ``mmyolov6.pt``.
    """
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli_options = (
        ('--src', 'yolov6s.pt', 'src yolov6 model path'),
        ('--dst', 'mmyolov6.pt', 'save path'),
    )
    for flag, default_path, help_text in cli_options:
        cli.add_argument(flag, default=default_path, help=help_text)
    opts = cli.parse_args()
    convert(src=opts.src, dst=opts.dst)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()