File size: 741 Bytes
a0bcaae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# Copyright (c) SenseTime Research. All rights reserved.
import pickle
import functools
import torch
from pti.pti_configs import paths_config, global_config
def toogle_grad(model, flag=True):
    """Set ``requires_grad`` on every parameter of ``model`` to ``flag``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: Value written to each parameter's ``requires_grad``;
            ``False`` freezes the parameters, ``True`` unfreezes them.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag
def load_tuned_G(run_id, type):
    """Load a fine-tuned generator checkpoint written by a previous run.

    The file read is ``{paths_config.checkpoints_dir}/model_{run_id}_{type}.pt``.
    It must contain a whole pickled ``torch.nn.Module`` (the result is used
    with ``.to()``/``.eval()``), not just a ``state_dict``.

    Args:
        run_id: Run identifier interpolated into the checkpoint file name.
        type: Second file-name component. The name shadows the builtin but is
            kept for caller compatibility.

    Returns:
        The generator moved to ``global_config.device``, in eval mode, cast
        with ``.float()``, and with ``requires_grad`` disabled on all
        parameters.
    """
    import inspect

    new_G_path = f'{paths_config.checkpoints_dir}/model_{run_id}_{type}.pt'

    # Map storages straight onto the target device, so a checkpoint saved on
    # a GPU can still be opened on a machine that lacks that device.
    load_kwargs = {'map_location': global_config.device}
    # torch>=2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module. Opt out explicitly where the argument exists (older
    # torch versions do not accept it). This executes arbitrary pickle code,
    # so only load checkpoints from trusted sources.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    with open(new_G_path, 'rb') as f:
        new_G = torch.load(f, **load_kwargs)
    new_G = new_G.to(global_config.device).eval().float()
    # The tuned generator is used for inference only: freeze its weights.
    toogle_grad(new_G, False)
    return new_G
def load_old_G():
    """Load the pretrained StyleGAN2-ADA generator (``G_ema``).

    Reads the pickle at ``paths_config.stylegan2_ada_shhq`` and takes its
    ``'G_ema'`` entry. Unlike ``load_tuned_G``, gradients are left as they
    are in the pickle (not explicitly frozen).

    Returns:
        The generator moved to ``global_config.device``, in eval mode, and
        cast with ``.float()``.
    """
    # pickle.load runs arbitrary code from the file; only point this at
    # trusted model snapshots.
    with open(paths_config.stylegan2_ada_shhq, 'rb') as handle:
        generator = pickle.load(handle)['G_ema']
        generator = generator.to(global_config.device)
        generator = generator.eval()
        generator = generator.float()
    return generator
|