File size: 8,695 Bytes
71b93be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Optional

import numpy as np
import tensorflow as tf
import torch


# (tf_param, torch_param) suffix pairs mapping the four variables of a TF
# batch-norm layer onto the corresponding PyTorch BatchNorm parameters/buffers.
_BN_PARAM_PAIRS = (
    ("beta", "bias"),
    ("gamma", "weight"),
    ("moving_mean", "running_mean"),
    ("moving_variance", "running_var"),
)

# Squeeze-and-excitation renames applied inside (non-fused) MBConv blocks.
_SE_NAME_PAIRS = (
    ("se.conv2d.bias", "se.conv_reduce.bias"),
    ("se.conv2d.kernel", "se.conv_reduce.weight"),
    ("se.conv2d_1.bias", "se.conv_expand.bias"),
    ("se.conv2d_1.kernel", "se.conv_expand.weight"),
)


def _rename_bn(name: str, tf_layer: str, torch_prefix: str) -> str:
    """Rename `<tf_layer>.<tf_param>` to `<torch_prefix>.<torch_param>` for all BN params."""
    for tf_param, torch_param in _BN_PARAM_PAIRS:
        name = name.replace("{}.{}".format(tf_layer, tf_param),
                            "{}.{}".format(torch_prefix, torch_param))
    return name


def _convert_name(tf_name: str, model_name: str, stage0_num: int, fused_conv_num: int) -> str:
    """Map one TF checkpoint variable name to the matching PyTorch state-dict key.

    The `model_name + "/"` prefix is stripped and "/" becomes ".". Stem, head and
    block variables are then renamed. Any other variable is reported with a
    print and returned with only that prefix/separator conversion applied.
    """
    name = tf_name.replace(model_name + "/", "").replace("/", ".")

    if "stem" in tf_name:
        name = name.replace("conv2d.kernel", "conv.weight")
        return _rename_bn(name, "tpu_batch_normalization", "bn")

    if "head" in tf_name:
        name = name.replace("conv2d.kernel", "project_conv.conv.weight")
        name = name.replace("dense.kernel", "classifier.weight")
        name = name.replace("dense.bias", "classifier.bias")
        return _rename_bn(name, "tpu_batch_normalization", "project_conv.bn")

    if "blocks" not in tf_name:
        print("not recognized name: " + tf_name)
        return name

    # e.g. blocks_0.conv2d.kernel -> block id "0", name becomes blocks.0.conv2d.kernel
    block_id = name.split(".", maxsplit=1)[0].replace("blocks_", "")
    name = name.replace("blocks_{}".format(block_id), "blocks.{}".format(block_id))
    block_idx = int(block_id)

    if block_idx < stage0_num:
        # expansion=1 fused MBConv: a single conv + BN, exposed as project_conv
        name = name.replace("conv2d.kernel", "project_conv.conv.weight")
        return _rename_bn(name, "tpu_batch_normalization", "project_conv.bn")

    # Anchor on the block prefix so that the SE sub-layer convs (se.conv2d*) are not matched.
    prefix = "blocks.{}.".format(block_id)
    name = name.replace(prefix + "conv2d.kernel", prefix + "expand_conv.conv.weight")
    name = _rename_bn(name, "tpu_batch_normalization", "expand_conv.bn")
    name = name.replace(prefix + "conv2d_1.kernel", prefix + "project_conv.conv.weight")

    if block_idx < fused_conv_num:
        # fused MBConv: expand conv + project conv
        return _rename_bn(name, "tpu_batch_normalization_1", "project_conv.bn")

    # MBConv: expand conv, depthwise conv, SE, project conv
    name = name.replace("depthwise_conv2d.depthwise_kernel", "dwconv.conv.weight")
    name = _rename_bn(name, "tpu_batch_normalization_1", "dwconv.bn")
    name = _rename_bn(name, "tpu_batch_normalization_2", "project_conv.bn")
    for tf_part, torch_part in _SE_NAME_PAIRS:
        name = name.replace(tf_part, torch_part)
    return name


def _check_rank(name: str, var: np.ndarray, rank: int) -> None:
    """Raise ValueError unless `var` has exactly `rank` dimensions (unlike assert, survives -O)."""
    if len(var.shape) != rank:
        raise ValueError("expected a {}-D tensor for '{}', got shape {}".format(
            rank, name, tuple(var.shape)))


def _to_torch_layout(name: str, var: np.ndarray) -> np.ndarray:
    """Reorder the axes of a TF checkpoint array into PyTorch's layout for key `name`."""
    if "conv" in name and "weight" in name and "bn" not in name and "dw" not in name:
        # conv kernel [h, w, in, out] -> [out, in, h, w]
        _check_rank(name, var, 4)
        return np.transpose(var, (3, 2, 0, 1))
    if "bn" in name:
        # batch-norm vectors are copied as-is
        return var
    if "dwconv" in name and "weight" in name:
        # depthwise kernel [h, w, channels, multiplier] -> [channels, multiplier, h, w];
        # matches a grouped conv weight [out, 1, h, w] when multiplier == 1 (assumed).
        _check_rank(name, var, 4)
        return np.transpose(var, (2, 3, 0, 1))
    if "classifier" in name and "weight" in name:
        # dense kernel [in, out] -> Linear weight [out, in]
        _check_rank(name, var, 2)
        return np.transpose(var, (1, 0))
    return var


def main(model_name: str = "efficientnetv2-s",
         tf_weights_path: str = "./efficientnetv2-s/model",
         stage0_num: int = 2,
         fused_conv_num: int = 10,
         save_path: Optional[str] = None):
    """Convert a TF EfficientNetV2 checkpoint into a PyTorch state-dict file.

    Args:
        model_name: variable-scope prefix stripped from every checkpoint name;
            also used in the default output filename.
        tf_weights_path: checkpoint path given to tf.train.list_variables /
            tf.train.load_checkpoint.
        stage0_num: number of leading blocks that are expansion=1 fused MBConv
            (one conv + BN, mapped to `project_conv`).
        fused_conv_num: blocks with index < fused_conv_num are fused MBConv;
            later blocks are MBConv with depthwise conv and SE.
        save_path: output file; defaults to "pre_" + model_name + ".pth".

    Raises:
        ValueError: if a conv/depthwise/classifier weight has an unexpected rank.
    """
    except_var = {"global_step"}
    # Skip variables whose name contains "Exponential" (presumably EMA shadow copies).
    var_names = [name for name, _ in tf.train.list_variables(tf_weights_path)
                 if "Exponential" not in name and name not in except_var]
    reader = tf.train.load_checkpoint(tf_weights_path)

    new_weights = {}
    for tf_name in var_names:
        new_name = _convert_name(tf_name, model_name, stage0_num, fused_conv_num)
        new_var = _to_torch_layout(new_name, reader.get_tensor(tf_name))
        new_weights[new_name] = torch.as_tensor(new_var)

    if save_path is None:
        save_path = "pre_" + model_name + ".pth"
    torch.save(new_weights, save_path)


if __name__ == '__main__':
    # Per-variant conversion settings. Only the variant named in `chosen` is
    # converted; switch it to "efficientnetv2-m" or "efficientnetv2-l" as needed.
    presets = {
        "efficientnetv2-s": {"tf_weights_path": "./efficientnetv2-s/model",
                             "stage0_num": 2,
                             "fused_conv_num": 10},
        "efficientnetv2-m": {"tf_weights_path": "./efficientnetv2-m/model",
                             "stage0_num": 3,
                             "fused_conv_num": 13},
        "efficientnetv2-l": {"tf_weights_path": "./efficientnetv2-l/model",
                             "stage0_num": 4,
                             "fused_conv_num": 18},
    }
    chosen = "efficientnetv2-s"
    main(model_name=chosen, **presets[chosen])