|
import argparse |
|
import math |
|
import torch |
|
|
|
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean |
|
|
|
|
|
# sqrt(2) factor applied to many of the converted tensors below.
# NOTE(review): presumably the activation gain that the bilinear architecture
# applies at runtime and the clean architecture expects folded into the stored
# parameters -- confirm against both architectures.
_SQRT2 = 2**0.5


def _fan_in_scale(weight, ndim):
    """Return ``1 / sqrt(fan_in)`` for a weight tensor of the expected rank.

    The fan-in is derived from the shape exactly as the conversion unpacks it:

    * ``ndim == 2``: ``(c_out, c_in)``             -> ``c_in``
    * ``ndim == 4``: ``(c_out, c_in, k1, k2)``     -> ``c_in * k1 * k2``
    * ``ndim == 5``: ``(_, c_out, c_in, k1, k2)``  -> ``c_in * k1 * k2``

    Args:
        weight: Tensor exposing ``size()``.
        ndim: Rank the caller expects ``weight`` to have (2, 4 or 5).

    Returns:
        The float ``1 / sqrt(fan_in)``.

    Raises:
        ValueError: If ``weight`` does not have rank ``ndim`` or ``ndim`` is
            not a supported rank.
    """
    size = tuple(weight.size())
    if len(size) != ndim:
        raise ValueError(f'Expected a {ndim}-D weight, got shape {size}')
    if ndim == 2:
        _, c_in = size
        return 1 / math.sqrt(c_in)
    if ndim == 4:
        _, c_in, k1, k2 = size
        return 1 / math.sqrt(c_in * k1 * k2)
    if ndim == 5:
        _, _, c_in, k1, k2 = size
        return 1 / math.sqrt(c_in * k1 * k2)
    raise ValueError(f'Unsupported weight rank {ndim}')


def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
    """Convert a bilinear GFPGAN state dict into the layout of ``checkpoint_clean``.

    Every entry of ``checkpoint_bilinear`` is classified by substrings of its
    key. Depending on the parameter family, it is renamed and/or rescaled by
    ``1 / sqrt(fan_in)`` (with an extra ``0.01`` multiplier for ``style_mlp``)
    and/or by ``sqrt(2)``. The result is then written into ``checkpoint_clean``.

    Args:
        checkpoint_bilinear: Mapping from parameter name to tensor. It is only
            read; neither the mapping nor its tensors are modified.
        checkpoint_clean: Mapping that receives the converted tensors. It is
            updated in place and also returned.

    Returns:
        ``checkpoint_clean``.

    Raises:
        ValueError: If an ``activate`` key under ``style_conv`` has neither 4
            nor 5 dot-separated parts, if a key cannot be unpacked into the
            number of parts its family expects, or if a weight has an
            unexpected rank.

    Key handling:
        Decoder (``stylegan_decoder``) keys that match no specific rule are
        copied unchanged. Keys that match no top-level family are skipped, as
        are unmatched keys inside the ``style_conv`` and ``condition``
        families. For skipped keys, ``checkpoint_clean`` keeps whatever value
        it already held.
    """
    for ori_k, ori_v in checkpoint_bilinear.items():
        if 'stylegan_decoder' in ori_k:
            if 'style_mlp' in ori_k:
                # stylegan_decoder.style_mlp.<i>.<var> lands at index 2*i - 1.
                lr_mul = 0.01
                prefix, name, idx, var = ori_k.split('.')
                crt_k = f'{prefix}.{name}.{int(idx) * 2 - 1}.{var}'
                if var == 'weight':
                    scale = _fan_in_scale(ori_v, 2) * lr_mul
                    checkpoint_clean[crt_k] = ori_v * scale * _SQRT2
                else:
                    checkpoint_clean[crt_k] = ori_v * lr_mul * _SQRT2

            elif 'modulation' in ori_k:
                # Checked before 'style_conv'/'to_rgb' so that modulation layers
                # nested inside them are handled here.
                lr_mul = 1
                var = ori_k.split('.')[-1]
                if var == 'weight':
                    scale = _fan_in_scale(ori_v, 2) * lr_mul
                    checkpoint_clean[ori_k] = ori_v * scale
                else:
                    checkpoint_clean[ori_k] = ori_v * lr_mul

            elif 'style_conv' in ori_k:
                if 'activate' in ori_k:
                    # <prefix>.style_conv1.activate.<var>        -> <prefix>.style_conv1.<var>
                    # <prefix>.style_convs.<i>.activate.<var>    -> <prefix>.style_convs.<i>.<var>
                    split_rlt = ori_k.split('.')
                    if len(split_rlt) == 4:
                        prefix, name, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{var}'
                    elif len(split_rlt) == 5:
                        prefix, name, idx, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{idx}.{var}'
                    else:
                        # Without this, crt_k would be unbound or left over from
                        # a previous iteration and overwrite an unrelated key.
                        raise ValueError(f'Unexpected activation key layout: {ori_k}')
                    crt_v = ori_v * _SQRT2
                    # Stored as a (1, C, 1, 1) tensor in the clean layout.
                    checkpoint_clean[crt_k] = crt_v.view(1, crt_v.size(0), 1, 1)
                elif 'modulated_conv' in ori_k:
                    checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 5)
                elif 'weight' in ori_k:
                    checkpoint_clean[ori_k] = ori_v * _SQRT2
                # Any other style_conv key is not copied.

            elif 'to_rgb' in ori_k:
                if 'modulated_conv' in ori_k:
                    checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 5)
                else:
                    checkpoint_clean[ori_k] = ori_v

            else:
                checkpoint_clean[ori_k] = ori_v

        elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
            # <name>.<sub>.<var> -> <name>.<var>
            name, _, var = ori_k.split('.')
            crt_k = f'{name}.{var}'
            if var == 'weight':
                checkpoint_clean[crt_k] = ori_v * _fan_in_scale(ori_v, 4) * _SQRT2
            else:
                checkpoint_clean[crt_k] = ori_v * _SQRT2

        elif 'conv_body' in ori_k:
            key = ori_k
            if 'conv_body_up' in key:
                # These two keys lack the sub-index the other conv_body keys
                # carry; insert one so every key splits into five parts.
                key = key.replace('conv2.weight', 'conv2.1.weight')
                key = key.replace('skip.weight', 'skip.1.weight')
            # <name1>.<idx1>.<name2>.<sub>.<var> -> <name1>.<idx1>.<name2>.<var>
            name1, idx1, name2, _, var = key.split('.')
            crt_k = f'{name1}.{idx1}.{name2}.{var}'
            if name2 == 'skip':
                checkpoint_clean[crt_k] = ori_v * _fan_in_scale(ori_v, 4) / _SQRT2
            else:
                if var == 'weight':
                    crt_v = ori_v * _fan_in_scale(ori_v, 4)
                else:
                    crt_v = ori_v
                if 'conv1' in key:
                    # Out-of-place on purpose: crt_v may be the caller's own
                    # tensor (bias case), and an in-place `*=` would corrupt
                    # checkpoint_bilinear.
                    crt_v = crt_v * _SQRT2
                checkpoint_clean[crt_k] = crt_v

        elif 'toRGB' in ori_k:
            if 'weight' in ori_k:
                checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 4)
            else:
                checkpoint_clean[ori_k] = ori_v

        elif 'final_linear' in ori_k:
            if 'weight' in ori_k:
                checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 2)
            else:
                checkpoint_clean[ori_k] = ori_v

        elif 'condition' in ori_k:
            # NOTE(review): sub-layers are matched by substring, so this relies
            # on the key layout (e.g. a '...10.weight' key would also match
            # '0.weight'). Unmatched condition keys are not copied.
            if '0.weight' in ori_k:
                checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 4) * _SQRT2
            elif '0.bias' in ori_k:
                checkpoint_clean[ori_k] = ori_v * _SQRT2
            elif '2.weight' in ori_k:
                checkpoint_clean[ori_k] = ori_v * _fan_in_scale(ori_v, 4)
            elif '2.bias' in ori_k:
                checkpoint_clean[ori_k] = ori_v

    return checkpoint_clean
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: build a GFPGANv1Clean model, convert the
    # 'params_ema' entry of the original checkpoint into its state-dict
    # layout with modify_checkpoint, and save the result under 'params_ema'.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without them torch.load/torch.save would fail
    # later with an unhelpful error.
    parser.add_argument('--ori_path', type=str, required=True, help='Path to the original model')
    parser.add_argument('--narrow', type=float, default=1)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--save_path', type=str, required=True, help='Path to save the converted model')
    args = parser.parse_args()

    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts
    # and keeps every tensor on the same device as net.state_dict() below.
    # NOTE: torch.load unpickles the file -- only convert checkpoints you trust.
    ori_ckpt = torch.load(args.ori_path, map_location='cpu')['params_ema']

    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sft_half=True)
    # The freshly initialized state dict supplies the target key set; entries
    # are overwritten with converted tensors.
    crt_ckpt = net.state_dict()

    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
    print(f'Save to {args.save_path}.')
    # Legacy (non-zipfile) serialization is kept for compatibility with older
    # torch.load versions.
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
|
|