import os
import glob

import numpy as np
from numpy import linalg
import PIL.Image as Image
import torch
from torchvision import transforms
from tqdm import tqdm
from argparse import Namespace
import easydict

import legacy
import dnnlib

from opensimplex import OpenSimplex

from configs import data_configs
from models.psp import pSp


def build_stylegan2(
    increment=0.01,
    network_pkl='pretrained/furry.pkl',
    process='image',            # ['image', 'interpolation', 'truncation', 'interpolation-truncation']
    random_seed=0,
    diameter=100.0,
    scale_type='pad',           # ['pad', 'padside', 'symm', 'symmside']
    size=[512, 512],
    seeds=[0],
    space='z',                  # ['z', 'w']
    fps=24,
    frames=240,
    noise_mode='none',          # ['const', 'random', 'none']
    outdir='path',
    projected_w='path',
    easing='linear',
    device='cpu',
):
    """Load a StyleGAN2 generator from a network pickle and return its synthesis network.

    Only `network_pkl`, `size`, `scale_type`, and `device` are used here; the remaining
    arguments mirror the upstream generation script and are kept for interface compatibility.
    """
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type

    device = torch.device(device)
    with dnnlib.util.open_url(network_pkl) as f:
        # custom=True loads the padded/resizable generator variant with the size/scale kwargs above
        G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device)  # type: ignore

    return G.synthesis


def build_psp():
    """Load the pretrained pSp encoder, restoring the options saved in its checkpoint."""
    test_opts = easydict.EasyDict({
        # arguments for inference script
        'checkpoint_path': 'pretrained/psp.pt',
        'couple_outputs': False,
        'resize_outputs': False,

        'test_batch_size': 1,
        'test_workers': 1,

        # arguments for style-mixing script
        'n_images': None,
        'n_outputs_to_generate': 5,
        'mix_alpha': None,
        'latent_mask': None,

        # arguments for super-resolution
        'resize_factors': None,
    })

    # update test options with the options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)
    opts.device = 'cpu'
    net = pSp(opts)
    net.eval()
    return net
    
def img_preprocess(img, transform):
    """Flatten any transparency onto a white background, apply `transform`, and add a batch dim."""
    if img.mode in ('RGBA', 'P'):
        # Palette images may carry transparency; convert first so an alpha channel exists.
        img = img.convert('RGBA')
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'
    img = transform(img)
    img = img.unsqueeze(dim=0)
    return img
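

# Minimal usage sketch, not part of the original module: encode an image with pSp and decode
# it with the StyleGAN2 synthesis network built above. The input path 'example.jpg', the
# 256x256 encoder input size, the pSp forward signature (randomize_noise/return_latents), and
# the synthesis call with noise_mode follow the reference pSp and stylegan2-ada-pytorch repos;
# adjust to the checkpoints and forks actually in use, and check that the pSp latent layout
# matches the generator's expected number of style vectors.
if __name__ == '__main__':
    synthesis = build_stylegan2()
    net = build_psp()

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    x = img_preprocess(Image.open('example.jpg'), transform)

    with torch.no_grad():
        # Encode to W+ latents with pSp, then decode with the StyleGAN2 synthesis network.
        _, latents = net(x, randomize_noise=False, return_latents=True)
        img_out = synthesis(latents, noise_mode='none')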