|
import os |
|
from PIL import Image |
|
import torch |
|
import gradio as gr |
|
|
|
|
|
import torch |
|
torch.backends.cudnn.benchmark = True |
|
from torchvision import transforms, utils |
|
from util import * |
|
from PIL import Image |
|
import math |
|
import random |
|
import numpy as np |
|
from torch import nn, autograd, optim |
|
from torch.nn import functional as F |
|
from tqdm import tqdm |
|
import lpips |
|
from model import * |
|
from e4e_projection import projection as e4e_projection |
|
|
|
from copy import deepcopy |
|
import imageio |
|
|
|
# Create the working directories; exist_ok makes this a no-op for any that
# already exist (e.g. on a restart).
for _dirname in ('inversion_codes', 'style_images', 'style_images_aligned', 'models'):
    os.makedirs(_dirname, exist_ok=True)
|
|
|
# Fetch dlib's shape_predictor_68_face_landmarks archive, decompress it
# (-d decompress, -k keep the .bz2) and move the .dat under models/.
# Exit statuses are not checked.
# NOTE(review): the destination name has no separator between "dlib" and
# "shape_predictor..." — presumably the name the face-alignment helper expects; confirm.
for _cmd in (
    "wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
    "bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2",
    "mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat",
):
    os.system(_cmd)
|
|
|
|
|
# All models in this app run on the CPU.
device = 'cpu'

# Presumably downloads 'stylegan2-ffhq-config-f.pt', which is loaded below
# (the Drive ID alone does not say) — confirm.
os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ")

latent_dim = 512

# Base 1024px StyleGAN2 generator. Positional args are presumably
# (size, style_dim, n_mlp, channel_multiplier) — confirm against model.py.
original_generator = Generator(1024, latent_dim, 8, 2).to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location='cpu')
# strict=False: missing/unexpected keys in "g_ema" are silently ignored.
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
# load_state_dict copies the weights into the model, so the full checkpoint
# dict is no longer needed; drop it so its tensors can be freed.
del ckpt

# Only a value is needed here. mean_latent is defined in model.py (not visible
# here); running it under no_grad guarantees no autograd graph over the 10000
# sampled latents is built and kept alive by the global below.
with torch.no_grad():
    mean_latent = original_generator.mean_latent(10000)
|
|
|
# One independent copy of the base FFHQ generator per style. Each copy later
# receives its style's checkpoint weights via load_state_dict.
(
    generatorjojo,
    generatordisney,
    generatorjinx,
    generatorcaitlyn,
    generatoryasuho,
    generatorarcanemulti,
    generatorart,
    generatorspider,
) = (deepcopy(original_generator) for _ in range(8))
|
|
|
|
|
# Image preprocessing: resize to 1024x1024, convert to a tensor, then
# normalize each RGB channel as (x - 0.5) / 0.5. For 8-bit images, ToTensor
# yields values in [0, 1], so the result lies in [-1, 1].
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
|
|
|
# Presumably downloads 'e4e_ffhq_encode.pt' (the name the copy step expects),
# then places a copy under models/. Exit statuses are not checked.
for _cmd in (
    "gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK",
    "cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt",
):
    os.system(_cmd)
|
|
|
def _load_style_checkpoint(generator, download_cmd, ckpt_path):
    """Download one fine-tuned style checkpoint and load its "g" weights into *generator*.

    The checkpoint dict lives only in this function's scope. Once its weights
    are copied into *generator* it can be garbage-collected, instead of pinning
    a second copy of every model in memory for the app's lifetime.

    Args:
        generator: Model whose state dict receives the checkpoint's "g" entry.
        download_cmd: Shell command (wget/gdown) that fetches the checkpoint.
        ckpt_path: Local filename the download command is expected to produce.
    """
    # The exit status is ignored: a failed download whose file is already
    # present (e.g. from a previous run) still loads. A missing file surfaces
    # as an error from torch.load.
    os.system(download_cmd)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # strict=False: missing/unexpected keys are silently ignored.
    generator.load_state_dict(ckpt["g"], strict=False)


_load_style_checkpoint(
    generatorjojo,
    "wget https://huggingface.co/akhaliq/JoJoGAN-jojo/resolve/main/jojo_preserve_color.pt",
    'jojo_preserve_color.pt',
)
_load_style_checkpoint(
    generatordisney,
    "wget https://huggingface.co/akhaliq/jojogan-disney/resolve/main/disney_preserve_color.pt",
    'disney_preserve_color.pt',
)
_load_style_checkpoint(
    generatorjinx,
    "gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney",
    'arcane_jinx_preserve_color.pt',
)
_load_style_checkpoint(
    generatorcaitlyn,
    "gdown https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
    'arcane_caitlyn_preserve_color.pt',
)
_load_style_checkpoint(
    generatoryasuho,
    "gdown https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
    'jojo_yasuho_preserve_color.pt',
)
_load_style_checkpoint(
    generatorarcanemulti,
    "wget https://huggingface.co/akhaliq/jojogan-arcane/resolve/main/arcane_multi_preserve_color.pt",
    'arcane_multi_preserve_color.pt',
)
_load_style_checkpoint(
    generatorart,
    "gdown https://drive.google.com/uc?id=1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",
    'art.pt',
)
_load_style_checkpoint(
    generatorspider,
    "wget https://huggingface.co/akhaliq/jojo-gan-spiderverse/resolve/main/spiderverse-checkpoint-3-face-500iters.pt",
    'spiderverse-checkpoint-3-face-500iters.pt',
)
|
|
|
|
|
def inference(img, model):
    """Stylize the face in an input image with the selected JoJoGAN generator.

    Args:
        img: Path to the input image (the Gradio input uses
            ``type="filepath"``); passed straight to ``align_face``.
        model: Dropdown label selecting the style: 'JoJo', 'Disney', 'Jinx',
            'Caitlyn', 'Yasuho', 'Arcane Multi' or 'Art'. Any other value,
            including 'Spider-Verse', selects the Spider-Verse generator.

    Returns:
        Path of the JPEG file the stylized image was written to.
    """
    aligned_face = align_face(img)

    # Invert the aligned face to a latent code. "test.pt" is the path argument
    # handed to e4e_projection (presumably where it saves the latent — confirm).
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)

    generators = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
        'Jinx': generatorjinx,
        'Caitlyn': generatorcaitlyn,
        'Yasuho': generatoryasuho,
        'Arcane Multi': generatorarcanemulti,
        'Art': generatorart,
    }
    generator = generators.get(model, generatorspider)

    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # Convert to 8-bit explicitly. The generator output is assumed to lie in
    # [-1, 1], the same range `transform` normalizes inputs to (TODO confirm).
    # Passing the raw float array to imageio instead gets it either rescaled
    # by the image's own min/max (shifting contrast and colour per image) or
    # rejected, depending on the imageio version.
    # my_sample[0] is (C, H, W) -> permuted to (H, W, C) for image writing.
    image = my_sample[0].detach().cpu().clamp(-1, 1)
    image = ((image + 1) * 127.5).round().to(torch.uint8)
    npimage = image.permute(1, 2, 0).numpy()

    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
|
|
|
# Text shown around the Gradio UI.
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"

# Each example row is (image path, model label).
examples = [['iu.jpeg', 'Jinx']]

# Build the UI: an image file path plus a model dropdown feed inference(),
# whose returned file path is displayed as the output image.
# The dropdown labels are the style names inference() dispatches on.
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath"),
        gr.inputs.Dropdown(
            choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho', 'Arcane Multi', 'Art', 'Spider-Verse'],
            type="value",
            default='JoJo',
            label="Model",
        ),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
).launch()
|
|