Spaces:
Runtime error
Runtime error
import os | |
import glob | |
import numpy as np | |
from numpy import linalg | |
import PIL.Image as Image | |
import torch | |
from torchvision import transforms | |
from tqdm import tqdm | |
from argparse import Namespace | |
import easydict | |
import legacy | |
import dnnlib | |
from opensimplex import OpenSimplex | |
from configs import data_configs | |
from models.psp import pSp | |
def build_stylegan2(
    increment=0.01,
    network_pkl='pretrained/ohayou_face2.pkl',
    process='image',  # ['image', 'interpolation', 'truncation', 'interpolation-truncation']
    random_seed=0,
    diameter=100.0,
    scale_type='pad',  # ['pad', 'padside', 'symm', 'symmside']
    size=None,
    seeds=None,
    space='z',  # ['z', 'w']
    fps=24,
    frames=240,
    noise_mode='none',  # ['const', 'random', 'none']
    outdir='path',
    projected_w='path',
    easing='linear',
    device='cpu',
):
    """Load a StyleGAN2 network pickle and return its synthesis network.

    Only ``network_pkl``, ``scale_type``, ``size`` and ``device`` are used
    here. The remaining parameters (``increment``, ``process``,
    ``random_seed``, ``diameter``, ``seeds``, ``space``, ``fps``, ``frames``,
    ``noise_mode``, ``outdir``, ``projected_w``, ``easing``) are accepted for
    call-site compatibility but are currently ignored.

    Args:
        network_pkl: Location of the network pickle, opened with
            ``dnnlib.util.open_url`` (presumably a local path or URL).
        scale_type: Forwarded to the loader as the ``scale_type`` kwarg.
        size: Forwarded to the loader as the ``size`` kwarg. When omitted, a
            fresh ``[512, 512]`` list is created for each call.
        device: Anything accepted by ``torch.device`` (e.g. ``'cpu'``); the
            loaded ``'G_ema'`` network is moved there.

    Returns:
        The ``synthesis`` attribute of the loaded ``'G_ema'`` network.
    """
    # Build the default per call. The previous ``size=[512, 512]`` default was
    # a single list object shared across every call and passed into the loader.
    if size is None:
        size = [512, 512]
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type
    device = torch.device(device)
    with dnnlib.util.open_url(network_pkl) as f:
        # custom=True plus the size/scale_type kwargs are handed to the
        # project's legacy loader; see legacy.load_network_pkl for their meaning.
        G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device)  # type: ignore
    return G.synthesis
def build_psp(checkpoint_path='pretrained/ohayou_face.pt'):
    """Load a pSp network from a checkpoint and return it in eval mode.

    The training-time options stored in the checkpoint (``ckpt['opts']``) are
    overridden by the inference-time options below. ``learn_in_w`` defaults
    to ``False`` when the checkpoint lacks it, and ``device`` is set to
    ``'cpu'`` before the network is constructed.

    Args:
        checkpoint_path: Path of the checkpoint file. Defaults to the
            previously hard-coded ``'pretrained/ohayou_face.pt'``.

    Returns:
        The ``pSp`` instance after calling ``.eval()`` on it.
    """
    # A plain dict is used here: the former EasyDict was only ever converted
    # back to a dict via ``vars()`` and held no nested values.
    test_opts = {
        # arguments for inference script
        'checkpoint_path': checkpoint_path,
        'couple_outputs': False,
        'resize_outputs': False,
        'test_batch_size': 1,
        'test_workers': 1,
        # arguments for style-mixing script
        'n_images': None,
        'n_outputs_to_generate': 5,
        'mix_alpha': None,
        'latent_mask': None,
        # arguments for super-resolution
        'resize_factors': None,
    }
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only pass checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # update test options with options used during training
    opts = ckpt['opts']
    opts.update(test_opts)
    # Older checkpoints may predate this option.
    opts.setdefault('learn_in_w', False)
    opts = Namespace(**opts)
    opts.device = 'cpu'
    net = pSp(opts)
    net.eval()
    return net
def img_preprocess(img, transform):
    """Convert a PIL image to RGB, apply ``transform``, and add a batch axis.

    Images with transparency are composited onto a white background. Palette
    (``'P'``/``'PA'``) and grayscale-with-alpha (``'LA'``) images are first
    converted to ``'RGBA'`` so that their alpha band can serve as the paste
    mask. Any other non-RGB mode is converted directly to ``'RGB'``.

    Args:
        img: A ``PIL.Image.Image``.
        transform: Callable applied to the RGB image. Its result must support
            ``.unsqueeze(dim=0)`` (presumably a torchvision transform that
            returns a ``torch.Tensor``).

    Returns:
        ``transform(img)`` with a new leading dimension of size 1.
    """
    # A 'P' image has a single band, so the old ``img.split()[3]`` raised
    # IndexError for it. Converting to RGBA turns palette transparency into
    # a real alpha band.
    if img.mode in ('P', 'PA', 'LA'):
        img = img.convert('RGBA')
    if img.mode == 'RGBA':
        img.load()
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # band 3 is the alpha channel
        img = background
    elif img.mode != 'RGB':
        # e.g. 'L' or 'CMYK': these were previously rejected by an assert,
        # which is stripped under ``python -O``.
        img = img.convert('RGB')
    img = transform(img)
    return img.unsqueeze(dim=0)