# Spaces: Runtime error
import os | |
from PIL import Image | |
import torch | |
import gradio as gr | |
import torch | |
torch.backends.cudnn.benchmark = True | |
from torchvision import transforms, utils | |
from util import * | |
from PIL import Image | |
import math | |
import random | |
import numpy as np | |
from torch import nn, autograd, optim | |
from torch.nn import functional as F | |
from tqdm import tqdm | |
import lpips | |
from model import * | |
#from e4e_projection import projection as e4e_projection | |
from copy import deepcopy | |
import imageio | |
import os | |
import sys | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from argparse import Namespace | |
from e4e.models.psp import pSp | |
from util import * | |
# --- e4e encoder setup ------------------------------------------------------
# Download the e4e FFHQ encoder checkpoint (once) and build the pSp network
# that `projection` uses to turn a face image into a latent code.
os.makedirs('models', exist_ok=True)
device = 'cpu'
model_path = 'models/e4e_ffhq_encode.pt'
if not os.path.exists(model_path):
    # Only fetch when the checkpoint is not already in place, and fail loudly
    # here instead of letting torch.load report a missing file later on.
    if os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK") != 0:
        raise RuntimeError("gdown failed to download e4e_ffhq_encode.pt")
    if os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt") != 0:
        raise RuntimeError("could not copy e4e_ffhq_encode.pt into models/")
# NOTE(review): torch.load unpickles the downloaded file; only keep this
# pointed at a checkpoint source you trust.
ckpt = torch.load(model_path, map_location='cpu')
# The checkpoint stores the options it was trained with; point them at the
# local file before handing them to pSp as an attribute namespace.
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)
net = pSp(opts, device).eval().to(device)
def projection(img, name, device='cuda'):
    """Encode an image into a latent code with the module-level e4e encoder.

    Args:
        img: Image accepted by the torchvision transforms below (presumably
            the PIL image returned by ``align_face`` -- confirm against util).
        name: Path the result is written to with ``torch.save`` as a dict
            holding the latent under the key ``'latent'``.
        device: Device the preprocessed tensor is moved to. It has to match
            the device ``net`` lives on; the caller in this file passes the
            module-level ``device``.

    Returns:
        The first entry of the latents returned by ``net`` (``w_plus[0]``).
    """
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    img = transform(img).unsqueeze(0).to(device)
    # Pure inference: without no_grad the encoder builds an autograd graph
    # and the saved/returned latent would keep it alive.
    with torch.no_grad():
        _, w_plus = net(img, randomize_noise=False, return_latents=True)
    latent = w_plus[0]
    torch.save({'latent': latent}, name)
    return latent
# --- model downloads and generator setup ------------------------------------
# Fetches the dlib landmark model and the StyleGAN2 FFHQ weights, then builds
# one fine-tuned generator per style. Downloads are skipped when the target
# file already exists and a failing command raises immediately.


def _run_or_fail(command):
    """Run a shell command with os.system and raise if it exits non-zero."""
    if os.system(command) != 0:
        raise RuntimeError("command failed: " + command)


def _load_finetuned_generator(url, filename):
    """Return a copy of ``original_generator`` with style weights loaded.

    The checkpoint at ``url`` is downloaded to ``filename`` in the working
    directory unless it is already there. Its ``"g"`` entry is loaded
    non-strictly into a deep copy of the base generator. The checkpoint dict
    is not kept around afterwards, so its tensors can be freed.
    """
    if not os.path.exists(filename):
        _run_or_fail("wget " + url)
    # NOTE(review): torch.load unpickles downloaded data; keep the URLs
    # pointed at sources you trust.
    checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
    generator = deepcopy(original_generator)
    generator.load_state_dict(checkpoint["g"], strict=False)
    return generator


# dlib 68-point landmark model (the target file name is kept exactly as the
# original script wrote it).
if not os.path.exists('models/dlibshape_predictor_68_face_landmarks.dat'):
    _run_or_fail("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
    _run_or_fail("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
    _run_or_fail("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")

device = 'cpu'

# Base StyleGAN2 generator that every style model starts from.
if not os.path.exists('stylegan2-ffhq-config-f.pt'):
    _run_or_fail("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ")
latent_dim = 512
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

# Preprocessing pipeline at generator resolution (resize to 1024x1024,
# convert to tensor, normalize each channel with mean 0.5 / std 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# One fine-tuned generator per style; the module-level names are the ones
# `inference` looks up.
generatorjojo = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/JoJoGAN-jojo/resolve/main/jojo_preserve_color.pt",
    'jojo_preserve_color.pt',
)
generatordisney = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojogan-disney/resolve/main/disney_preserve_color.pt",
    'disney_preserve_color.pt',
)
generatorjinx = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojo-gan-jinx/resolve/main/arcane_jinx_preserve_color.pt",
    'arcane_jinx_preserve_color.pt',
)
generatorcaitlyn = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojogan-arcane/resolve/main/arcane_caitlyn_preserve_color.pt",
    'arcane_caitlyn_preserve_color.pt',
)
generatoryasuho = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/JoJoGAN-jojo/resolve/main/jojo_yasuho_preserve_color.pt",
    'jojo_yasuho_preserve_color.pt',
)
generatorarcanemulti = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojogan-arcane/resolve/main/arcane_multi_preserve_color.pt",
    'arcane_multi_preserve_color.pt',
)
generatorart = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojo-gan-art/resolve/main/art.pt",
    'art.pt',
)
generatorspider = _load_finetuned_generator(
    "https://huggingface.co/akhaliq/jojo-gan-spiderverse/resolve/main/Spiderverse-face-500iters-8face.pt",
    'Spiderverse-face-500iters-8face.pt',
)
def inference(img, model):
    """Stylize the face in ``img`` with the generator selected by ``model``.

    Args:
        img: Input image as handed over by the Gradio image component (a file
            path, given ``type="filepath"`` in the interface definition); it
            is passed unchanged to ``align_face``.
        model: Style name from the dropdown. Any value that is not one of the
            listed names falls back to the Spider-Verse generator.

    Returns:
        The path of the written JPEG, ``'filename.jpeg'``.
    """
    aligned_face = align_face(img)
    my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)

    # Lookup table instead of an if/elif chain; 'Spider-Verse' (and anything
    # unrecognized) is covered by the default below.
    generators = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
        'Jinx': generatorjinx,
        'Caitlyn': generatorcaitlyn,
        'Yasuho': generatoryasuho,
        'Arcane Multi': generatorarcanemulti,
        'Art': generatorart,
    }
    generator = generators.get(model, generatorspider)
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # Convert to 8-bit explicitly rather than passing floats to imageio,
    # whose implicit float handling differs between versions.
    # Assumes the generator output is nominally in [-1, 1], matching the
    # Normalize(0.5, 0.5) convention used for inputs -- TODO confirm.
    frame = my_sample[0].detach().clamp(-1, 1)
    frame = ((frame + 1) * 127.5).round().to(torch.uint8)
    npimage = frame.permute(1, 2, 0).cpu().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
# --- Gradio UI --------------------------------------------------------------
# Static text shown around the demo.
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
examples = [['mona.png', 'Jinx']]

# Input components: the uploaded image (passed to `inference` as a file path)
# and the style selector.
image_input = gr.inputs.Image(type="filepath")
model_input = gr.inputs.Dropdown(
    choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho', 'Arcane Multi', 'Art', 'Spider-Verse'],
    type="value",
    default='JoJo',
    label="Model",
)
# Output component: `inference` returns the path of the image file it wrote.
image_output = gr.outputs.Image(type="file")

demo = gr.Interface(
    inference,
    [image_input, model_input],
    image_output,
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples=examples,
    allow_screenshot=False,
    enable_queue=True,
)
demo.launch()