Spaces:
Runtime error
Runtime error
File size: 1,271 Bytes
92ec8d3 34e88c0 92ec8d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
from argparse import Namespace
import sys
sys.path.extend(['.', '..'])
from models.stylegan2.model import Generator
from models.hyperstyle.hyperstyle import HyperStyle
from models.hyperstyle.encoders.e4e import e4e
def load_model(checkpoint_path, device='cuda', update_opts=None, is_restyle_encoder=False):
    """Load a HyperStyle network (or an e4e encoder) from a checkpoint file.

    Args:
        checkpoint_path: Path handed to ``torch.load``. The checkpoint must
            contain an ``'opts'`` entry holding the saved options mapping.
        device: Fallback device, used only when neither the saved options
            nor ``update_opts`` supply a ``device`` entry.
        update_opts: Optional overrides merged over the saved options. May
            be a dict (including dict subclasses) or any object accepted by
            ``vars()``, e.g. an ``argparse.Namespace``.
        is_restyle_encoder: When true, build ``e4e``; otherwise build
            ``HyperStyle``.

    Returns:
        A ``(net, opts)`` tuple: the network after ``eval()`` and
        ``to(opts.device)`` have been called on it, and the merged options
        as an ``argparse.Namespace``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    opts = ckpt['opts']

    if update_opts is not None:
        # isinstance (not an exact type comparison) so dict subclasses such
        # as OrderedDict are merged by content instead of going through
        # vars(), which would yield their (typically empty) instance __dict__.
        if isinstance(update_opts, dict):
            overrides = update_opts
        else:
            overrides = vars(update_opts)
        opts.update(overrides)

    # Applied after the merge so overrides can never replace these two
    # entries.
    opts['checkpoint_path'] = checkpoint_path
    opts['load_w_encoder'] = True
    # Only fills a gap: an existing 'device' option always wins over the
    # `device` argument.
    opts.setdefault('device', device)
    opts = Namespace(**opts)

    net = e4e(opts) if is_restyle_encoder else HyperStyle(opts)
    net.eval()
    net.to(opts.device)
    return net, opts
def load_generator(checkpoint_path, device='cuda'):
    """Build a ``Generator`` and populate it from a checkpoint's ``'g_ema'`` entry.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: Target passed to the generator's ``to()`` call.

    Returns:
        The generator after ``load_state_dict``, ``eval()`` and
        ``to(device)`` have been called on it.
    """
    print(f"Loading generator from checkpoint: {checkpoint_path}")

    # Positional arguments 1024, 512, 8 are fixed here; their meaning is
    # defined by the Generator class (presumably output size, style
    # dimension and mapping-layer count -- confirm against its signature).
    g_net = Generator(1024, 512, 8, channel_multiplier=2)

    # Deserialise onto the CPU first; the move to `device` happens below.
    saved = torch.load(checkpoint_path, map_location='cpu')
    ema_weights = saved['g_ema']
    g_net.load_state_dict(ema_weights)

    g_net.eval()
    g_net.to(device)
    return g_net
|