import torch
from argparse import Namespace

import sys
sys.path.extend(['.', '..'])

from models.stylegan2.model import Generator
from models.hyperstyle.hyperstyle import HyperStyle
from models.hyperstyle.encoders.e4e import e4e


def load_model(checkpoint_path, device='cuda', update_opts=None, is_restyle_encoder=False):
    """Load a trained HyperStyle network (or its e4e encoder) from a checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    opts = ckpt['opts']

    # Allow callers to override options stored in the checkpoint.
    if update_opts is not None:
        if isinstance(update_opts, dict):
            opts.update(update_opts)
        else:
            opts.update(vars(update_opts))

    # These must point at the checkpoint being loaded, regardless of any overrides.
    opts['checkpoint_path'] = checkpoint_path
    opts['load_w_encoder'] = True

    opts = Namespace(**opts)

    if is_restyle_encoder:
        net = e4e(opts)
    else:
        net = HyperStyle(opts)

    net.eval()
    net.to(device)
    return net, opts


def load_generator(checkpoint_path, device='cuda'):
    """Load a pretrained StyleGAN2 generator (1024x1024 output, 512-dim style, 8-layer mapping MLP)."""
    print(f"Loading generator from checkpoint: {checkpoint_path}")
    generator = Generator(1024, 512, 8, channel_multiplier=2)
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Use the exponential-moving-average generator weights stored in the checkpoint.
    generator.load_state_dict(ckpt['g_ema'])
    generator.eval()
    generator.to(device)
    return generator
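

# Minimal usage sketch. The checkpoint paths below are placeholders, not files shipped
# with this module; substitute your own trained checkpoints. The generator call assumes
# models.stylegan2.model.Generator follows the rosinality stylegan2-pytorch interface.
if __name__ == '__main__':
    # Load the HyperStyle network together with the options it was trained with.
    net, opts = load_model('pretrained_models/hyperstyle_ffhq.pt', device='cuda')

    # Load a standalone StyleGAN2 generator from its g_ema weights.
    generator = load_generator('pretrained_models/stylegan2-ffhq-config-f.pt', device='cuda')

    # Sample a random latent and synthesize an image with the generator.
    with torch.no_grad():
        z = torch.randn(1, 512, device='cuda')
        image, _ = generator([z], input_is_latent=False, randomize_noise=False)
    print(image.shape)  # expected: torch.Size([1, 3, 1024, 1024])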