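"""Convert a GFPGANv1 checkpoint (bilinear version) to the GFPGANv1Clean
architecture.

The original model uses equalized-learning-rate layers (EqualLinear,
EqualConv2d, ModulatedConv2d) that rescale their weights at runtime, plus
fused/scaled LeakyReLU activations that apply a sqrt(2) gain. The clean
architecture uses plain nn.Linear / nn.Conv2d layers, so this script bakes
all of those scaling factors directly into the stored parameters.
"""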
import argparse
import math
import torch

from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean


def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
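    """Map an original (bilinear) GFPGAN checkpoint onto the clean architecture.

    The equalized-lr scale factors and activation gains are baked into the
    copied weights so that plain nn.Conv2d / nn.Linear layers reproduce the
    original outputs.

    Args:
        checkpoint_bilinear (dict): state_dict of the original GFPGANv1 model.
        checkpoint_clean (dict): state_dict of GFPGANv1Clean; updated in place.

    Returns:
        dict: The updated ``checkpoint_clean``.
    """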
    for ori_k, ori_v in checkpoint_bilinear.items():
        if 'stylegan_decoder' in ori_k:
            if 'style_mlp' in ori_k:  # style_mlp_layers
                lr_mul = 0.01
                prefix, name, idx, var = ori_k.split('.')
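                # the clean style MLP interleaves nn.Linear and LeakyReLU, so
                # EqualLinear layer i maps to index 2 * i - 1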
                idx = (int(idx) * 2) - 1
                crt_k = f'{prefix}.{name}.{idx}.{var}'
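                # EqualLinear scales its weight by (1 / sqrt(c_in)) * lr_mul at
                # runtime and its fused LeakyReLU adds a sqrt(2) gain; bake both in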
                if var == 'weight':
                    _, c_in = ori_v.size()
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale * 2**0.5
                else:
                    crt_v = ori_v * lr_mul * 2**0.5
                checkpoint_clean[crt_k] = crt_v
            elif 'modulation' in ori_k:  # modulation in StyleConv
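                # the modulation layer is an EqualLinear with lr_mul=1 and no
                # activation, so only the equalized-lr scale is baked in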
                lr_mul = 1
                crt_k = ori_k
                var = ori_k.split('.')[-1]
                if var == 'weight':
                    _, c_in = ori_v.size()
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale
                else:
                    crt_v = ori_v * lr_mul
                checkpoint_clean[crt_k] = crt_v
            elif 'style_conv' in ori_k:
                # StyleConv in style_conv1 and style_convs
                if 'activate' in ori_k:  # FusedLeakyReLU
                    # e.g. style_conv1.activate.bias
                    # e.g. style_convs.13.activate.bias
                    split_rlt = ori_k.split('.')
                    if len(split_rlt) == 4:
                        prefix, name, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{var}'
                    elif len(split_rlt) == 5:
                        prefix, name, idx, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{idx}.{var}'
                    crt_v = ori_v * 2**0.5  # 2**0.5 used in FusedLeakyReLU
                    c = crt_v.size(0)
                    checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
                elif 'modulated_conv' in ori_k:
                    # e.g. style_conv1.modulated_conv.weight
                    # e.g. style_convs.13.modulated_conv.weight
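                    # ModulatedConv2d weights have shape (1, c_out, c_in, k, k);
                    # fold in the equalized-lr scale 1 / sqrt(fan_in)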
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                elif 'weight' in ori_k:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif 'to_rgb' in ori_k:  # ToRGB in to_rgb1 and to_rgbs
                if 'modulated_conv' in ori_k:
                    # e.g. to_rgb1.modulated_conv.weight
                    # e.g. to_rgbs.5.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                else:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v
            else:
                crt_k = ori_k
                checkpoint_clean[crt_k] = ori_v
            # end of 'stylegan_decoder'
        elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
            # key name
            name, _, var = ori_k.split('.')
            crt_k = f'{name}.{var}'
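            # the original layers are EqualConv2d + FusedLeakyReLU: fold the
            # weight scale and the sqrt(2) activation gain into the parameters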
            # weight and bias
            if var == 'weight':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            else:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
        elif 'conv_body' in ori_k:
            if 'conv_body_up' in ori_k:
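                # the clean arch wraps these convs in nn.Sequential modules
                # (the conv sits at index 1), so remap the keys accordingly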
                ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
                ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
            name1, idx1, name2, _, var = ori_k.split('.')
            crt_k = f'{name1}.{idx1}.{name2}.{var}'
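            # the skip branch additionally absorbs the ResBlock's 1/sqrt(2)
            # residual averaging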
            if name2 == 'skip':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
            else:
                if var == 'weight':
                    c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    checkpoint_clean[crt_k] = ori_v * scale
                else:
                    checkpoint_clean[crt_k] = ori_v
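                # conv1 is followed by a fused LeakyReLU in the original block;
                # fold in its sqrt(2) gain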
                if 'conv1' in ori_k:
                    checkpoint_clean[crt_k] *= 2**0.5
        elif 'toRGB' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'final_linear' in ori_k:
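            # final_linear is an EqualLinear without activation, so only the
            # 1 / sqrt(c_in) scale is folded in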
            crt_k = ori_k
            if 'weight' in ori_k:
                _, c_in = ori_v.size()
                scale = 1 / math.sqrt(c_in)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'condition' in ori_k:
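            # SFT condition branches: layer 0 is followed by a ScaledLeakyReLU
            # (sqrt(2) gain); layer 2 has no activation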
            crt_k = ori_k
            if '0.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            elif '0.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif '2.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            elif '2.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v

    return checkpoint_clean


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ori_path', type=str, required=True, help='Path to the original GFPGANv1 model')
    parser.add_argument('--narrow', type=float, default=1)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--save_path', type=str, required=True, help='Path to save the converted model')
    args = parser.parse_args()

    # load only the EMA weights; map to CPU so conversion does not require a GPU
    ori_ckpt = torch.load(args.ori_path, map_location='cpu')['params_ema']

    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sft_half=True)
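    # template state_dict with the clean architecture's key names and shapes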
    crt_ckpt = net.state_dict()

    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
    print(f'Save to {args.save_path}.')
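    # legacy (non-zipfile) serialization keeps the result loadable by older torch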
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)