Ohayou_Face / model_build.py
Reeve's picture
Update model_build.py
50bf3a8
raw
history blame
2.59 kB
import os
import glob
import numpy as np
from numpy import linalg
import PIL.Image as Image
import torch
from torchvision import transforms
from tqdm import tqdm
from argparse import Namespace
import easydict
import legacy
import dnnlib
from opensimplex import OpenSimplex
from configs import data_configs
from models.psp import pSp
def build_stylegan2(
    increment=0.01,
    network_pkl='pretrained/ohayou_face2.pkl',
    process='image',  # ['image', 'interpolation', 'truncation', 'interpolation-truncation']
    random_seed=0,
    diameter=100.0,
    scale_type='pad',  # ['pad', 'padside', 'symm', 'symmside']
    size=[512, 512],
    seeds=[0],
    space='z',  # ['z', 'w']
    fps=24,
    frames=240,
    noise_mode='none',  # ['const', 'random', 'none']
    outdir='path',
    projected_w='path',
    easing='linear',
    device='cpu'
):
    """Load a generator pickle and return its synthesis sub-network.

    Only ``network_pkl``, ``size``, ``scale_type`` and ``device`` are read
    by this function; every other parameter is accepted but not used here.

    Args:
        network_pkl: Path or URL handed to ``dnnlib.util.open_url``.
        size: Forwarded to ``legacy.load_network_pkl`` as the ``size``
            keyword.
        scale_type: Forwarded to ``legacy.load_network_pkl`` as the
            ``scale_type`` keyword.
        device: Anything ``torch.device`` accepts; the ``G_ema`` entry of
            the loaded pickle is moved there.

    Returns:
        The ``synthesis`` attribute of the loaded ``G_ema`` network.
    """
    # NOTE(review): `size` and `seeds` default to list objects shared across
    # calls. This function does not modify them itself; whether the loader
    # does is not visible from here.
    loader_kwargs = dnnlib.EasyDict(size=size, scale_type=scale_type)
    target_device = torch.device(device)

    # The `custom=True` path of the loader takes the extra size/scale_type
    # keywords; the plain call would be `legacy.load_network_pkl(f)`.
    with dnnlib.util.open_url(network_pkl) as pkl_file:
        networks = legacy.load_network_pkl(pkl_file, custom=True, **loader_kwargs)
        generator = networks['G_ema'].to(target_device)  # type: ignore

    return generator.synthesis
def build_psp():
    """Build a pSp network in eval mode from the bundled checkpoint.

    The options stored under ``'opts'`` in the checkpoint are overridden
    with the fixed inference options below, ``learn_in_w`` is filled in as
    ``False`` when the checkpoint does not define it, and ``device`` is
    forced to ``'cpu'``.

    Returns:
        The constructed ``pSp`` instance after ``eval()`` has been called.
    """
    inference_opts = easydict.EasyDict({
        # arguments for inference script
        'checkpoint_path' : 'pretrained/ohayou_face.pt',
        'couple_outputs' : False,
        'resize_outputs' : False,
        'test_batch_size' : 1,
        'test_workers' : 1,
        # arguments for style-mixing script
        'n_images' : None,
        'n_outputs_to_generate' : 5,
        'mix_alpha' : None,
        'latent_mask' : None,
        # arguments for super-resolution
        'resize_factors' : None,
    })

    # Start from the options saved with the checkpoint, then let the
    # inference options above take precedence.
    checkpoint = torch.load(inference_opts.checkpoint_path, map_location='cpu')
    merged = checkpoint['opts']
    merged.update(vars(inference_opts))
    merged.setdefault('learn_in_w', False)

    psp_opts = Namespace(**merged)
    psp_opts.device = 'cpu'

    encoder = pSp(psp_opts)
    encoder.eval()
    return encoder
def img_preprocess(img, transform):
    """Turn a PIL image into a single-item batch.

    Images in mode ``RGBA`` or ``P`` are composited onto a white
    background so that transparent pixels come out white. The image must
    be ``RGB`` after that step.

    Args:
        img: A ``PIL.Image.Image``.
        transform: Callable applied to the RGB image. Its result must
            provide ``unsqueeze(dim=...)``.

    Returns:
        ``transform(img)`` with a new leading axis inserted at position 0.

    Raises:
        AssertionError: If the image is not ``RGB`` after the optional
            compositing step (for example ``L`` or ``CMYK`` input).
    """
    if img.mode in ('RGBA', 'P'):
        # A palette ('P') image has a single band, so taking band 3 of the
        # raw image raises IndexError. Converting to RGBA first guarantees
        # an alpha band for both modes; convert() also loads the pixel data.
        rgba = img.convert('RGBA')
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[3])  # band 3 is the alpha channel
        img = background
    # Kept as an assert so callers see the same exception type as before.
    assert img.mode == 'RGB', f"expected an RGB image, got mode {img.mode!r}"
    img = transform(img)
    img = img.unsqueeze(dim=0)
    return img