JoJoGAN / app.py
Ahsen Khaliq
Update app.py
8fd4221
raw history blame
No virus
4.09 kB
import os
from PIL import Image
import torch
import gradio as gr
# Install dlib at process start by shelling out to pip; the exit status is ignored.
# NOTE(review): presumably required by the star-imported `util` module below — confirm.
os.system("pip install dlib")
# Repeated import: torch is already imported earlier in this file.
import torch
# Turn on cuDNN benchmark mode.
# NOTE(review): `device` is set to 'cpu' later in this file, so this likely has no effect here — confirm.
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *
from e4e_projection import projection as e4e_projection
from copy import deepcopy
import imageio
# Create the working directories used by this app (no error if they already exist).
for _dirname in ('inversion_codes', 'style_images', 'style_images_aligned', 'models'):
    os.makedirs(_dirname, exist_ok=True)

# Download, decompress and move the dlib 68-point landmark predictor into models/.
# Exit codes of the shell commands are not checked.
for _command in (
    "wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
    "bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2",
    "mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat",
):
    os.system(_command)

# All models in this script are placed on the CPU.
device = 'cpu'

# NOTE(review): presumably fetches 'stylegan2-ffhq-config-f.pt', which is loaded below — confirm.
os.system("gdown https://drive.google.com/uc?id=1-AG7JPTWc9REBrkll3OyEpZwSOWhlX0j")

latent_dim = 512


def _keep_storage(storage, loc):
    """map_location hook for torch.load: return the storage unchanged."""
    return storage


# Load original generator
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=_keep_storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

# Generators that receive the fine-tuned weights further down; each starts as
# an independent copy of the original generator.
generatorjojo = deepcopy(original_generator)
generatordisney = deepcopy(original_generator)

# Preprocessing pipeline: resize to 1024x1024, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Fetch a file with gdown, then copy e4e_ffhq_encode.pt into models/.
os.system("gdown https://drive.google.com/uc?id=1-7UlCppmiG4DKbhYDNbIZTc6mHy9JMWJ")
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")

# NOTE(review): `plt` is not imported by name in this file; presumably it is
# provided by one of the star imports — confirm.
plt.rcParams['figure.dpi'] = 150

# JoJo weights: download, then load the "g" entry into the JoJo generator.
os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
ckptjojo = torch.load('jojo.pt', map_location=_keep_storage)
generatorjojo.load_state_dict(ckptjojo["g"], strict=False)

# Disney weights: download, then load the "g" entry into the Disney generator.
os.system("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi")
ckptdisney = torch.load('disney_preserve_color.pt', map_location=_keep_storage)
generatordisney.load_state_dict(ckptdisney["g"], strict=False)
def inference(img, model):
    """Stylize the face in ``img`` with the selected fine-tuned generator.

    Args:
        img: Input image passed straight to ``align_face``. The Gradio image
            input in this file is configured with ``type="filepath"``.
        model: ``'JoJo'`` selects ``generatorjojo``; any other value selects
            ``generatordisney``.

    Returns:
        The string ``'filename.jpeg'``, the path the generated image is
        written to. The same path is reused on every call.
    """
    aligned_face = align_face(img)
    # Project the aligned face to a latent code and add a leading dimension.
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)

    # The previous code called ``generator.eval()`` in both branches, but no
    # ``generator`` is defined in this file, so the call raised NameError.
    # Pick the requested generator once and put *that* model in eval mode.
    generator = generatorjojo if model == 'JoJo' else generatordisney
    with torch.no_grad():
        generator.eval()
        my_sample = generator(my_w, input_is_latent=True)

    # Move the first axis of the first sample to the end before saving.
    # NOTE(review): presumably CHW -> HWC for the image writer — confirm.
    npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
# Text displayed on the demo page.
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
# Example inputs offered in the UI.
examples = [['iu.jpeg']]

# Build and launch the demo: an image (passed to `inference` as a file path)
# plus a model selector in, an image file out.
# The dropdown is created through the `gr` alias; the module is imported only
# as `gr`, so the former `gradio.inputs.Dropdown` raised NameError at startup.
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath", shape=(256, 256)),
        gr.inputs.Dropdown(choices=['JoJo', 'Disney'], type="value", default='JoJo', label="Model"),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
).launch()