"""Builders for the StyleGAN2 synthesis network and the pSp encoder, plus image preprocessing."""

import glob
import os
from argparse import Namespace

import easydict
import numpy as np
import PIL.Image as Image
import torch
from numpy import linalg
from opensimplex import OpenSimplex
from torchvision import transforms
from tqdm import tqdm

import dnnlib
import legacy
from configs import data_configs
from models.psp import pSp


def build_stylegan2(
    increment = 0.01,
    network_pkl = 'pretrained/ohayou_face2.pkl',
    process = 'image',  # ['image', 'interpolation', 'truncation', 'interpolation-truncation']
    random_seed = 0,
    diameter = 100.0,
    scale_type = 'pad',  # ['pad', 'padside', 'symm', 'symmside']
    size = [512, 512],
    seeds = [0],
    space = 'z',  # ['z', 'w']
    fps = 24,
    frames = 240,
    noise_mode = 'none',  # ['const', 'random', 'none']
    outdir = 'path',
    projected_w = 'path',
    easing = 'linear',
    device = 'cpu'
):
    """Load a StyleGAN2 generator pickle and return its synthesis sub-network.

    Only ``network_pkl``, ``scale_type``, ``size`` and ``device`` are read by
    this function. Every other parameter is accepted for call-site
    compatibility and is not used here.

    Args:
        network_pkl: Path or URL handed to ``dnnlib.util.open_url``.
        scale_type: Forwarded to ``legacy.load_network_pkl`` as a keyword.
        size: Forwarded to ``legacy.load_network_pkl`` as a keyword. The
            object is passed through by reference (not copied), so the
            shared default list must not be mutated downstream.
        device: Anything accepted by ``torch.device``; the loaded ``G_ema``
            network is moved there.

    Returns:
        The ``synthesis`` attribute of the loaded ``G_ema`` network.
    """
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type

    device = torch.device(device)
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device)  # type: ignore

    return G.synthesis


def build_psp():
    """Build a pSp network in eval mode from ``pretrained/ohayou_face.pt``.

    The options stored in the checkpoint under ``'opts'`` are overridden with
    the inference options defined below, ``learn_in_w`` is defaulted to
    ``False`` when the checkpoint does not carry it, and the device is forced
    to ``'cpu'``.

    Returns:
        The constructed ``pSp`` instance after ``eval()`` has been called.
    """
    test_opts = easydict.EasyDict({
        # arguments for inference script
        'checkpoint_path': 'pretrained/ohayou_face.pt',
        'couple_outputs': False,
        'resize_outputs': False,

        'test_batch_size': 1,
        'test_workers': 1,

        # arguments for style-mixing script
        'n_images': None,
        'n_outputs_to_generate': 5,
        'mix_alpha': None,
        'latent_mask': None,

        # arguments for super-resolution
        'resize_factors': None,
    })

    # update test options with options used during training
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)
    opts.device = 'cpu'

    net = pSp(opts)
    net.eval()
    return net


def img_preprocess(img, transform):
    """Flatten transparency onto white, apply ``transform``, add a leading axis.

    Args:
        img: A PIL image. ``'RGBA'`` and ``'P'`` images are composited onto a
            white RGB background; the result must be in ``'RGB'`` mode.
        transform: Callable applied to the RGB image. Its result must support
            ``.unsqueeze(dim=0)`` (presumably a torchvision transform that
            yields a tensor -- confirm against the caller).

    Returns:
        ``transform(img)`` with a new dimension inserted at position 0.

    Raises:
        AssertionError: If the image is not ``'RGB'`` after the optional
            compositing step (e.g. a grayscale ``'L'`` input).
    """
    if img.mode == 'P':
        # A palette image has a single band, so there is no alpha band to
        # split off; expand it to RGBA first (palette transparency, if any,
        # becomes the alpha band).
        img = img.convert('RGBA')

    if img.mode == 'RGBA':
        img.load()  # make sure pixel data is read before splitting bands
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # 3 is the alpha channel
        img = background

    assert img.mode == 'RGB', f"expected an RGB image, got mode {img.mode!r}"

    img = transform(img)
    img = img.unsqueeze(dim=0)
    return img