Spaces:
Runtime error
Runtime error
File size: 5,149 Bytes
8d4d98f f338a52 2a30e2f 58091b8 8b17f8c 8d4d98f 3f2d4dc 8d4d98f 58091b8 8d4d98f 2a30e2f 8d4d98f d8e7405 8d4d98f ddaf006 8d4d98f 8fd4221 ec8f0b0 8d4d98f 474b6cf 46a015d feb9da2 aa0db9c 91959e5 38887ff 91959e5 ec8f0b0 91959e5 2a30e2f 58091b8 2a30e2f 58091b8 2a30e2f 4c71d5b 76c051f 2a30e2f 91959e5 ec8f0b0 91959e5 ec8f0b0 aa0db9c a058c0e 3f2d4dc 8d4d98f a058c0e 5d457fc aa0db9c fec4733 5d457fc ec8f0b0 f338a52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
from PIL import Image
import torch
import gradio as gr
os.system("pip install gradio==2.5.3")
os.system("pip install facexlib")
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
#os.system("pip install autocrop")
#os.system("pip install dlib")
#from autocrop import Cropper
import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *
from e4e_projection import projection as e4e_projection
from copy import deepcopy
import imageio
import shutil
import subprocess

# Working directories expected by the rest of the pipeline.
os.makedirs('inversion_codes', exist_ok=True)
os.makedirs('style_images', exist_ok=True)
os.makedirs('style_images_aligned', exist_ok=True)
os.makedirs('models', exist_ok=True)

# Face detector + aligner used by inference() to produce a 512x512 face crop
# (facexlib replaces the earlier dlib/autocrop-based alignment).
face_helper = FaceRestoreHelper(
    upscale_factor=0,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device='cpu')
device = 'cpu'


def _gdrive_fetch(file_id, filename):
    """Download Google Drive file *file_id* to *filename* with the gdown CLI.

    The previous ``os.system("gdown ...")`` calls ignored failures, so a failed
    download only surfaced later as a confusing load error (or, for files that
    are read lazily, during a user request). This fails fast at startup and
    pins the output filename instead of relying on the name stored on Drive.

    Returns:
        *filename*, for convenience.

    Raises:
        RuntimeError: if gdown exits non-zero or *filename* is missing afterwards.
    """
    url = f"https://drive.google.com/uc?id={file_id}"
    result = subprocess.run(["gdown", url, "-O", filename])
    if result.returncode != 0 or not os.path.exists(filename):
        raise RuntimeError(f"gdown could not download {url} to {filename!r}")
    return filename


latent_dim = 512

# Load original (FFHQ) generator.
_gdrive_fetch("1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ", 'stylegan2-ffhq-config-f.pt')
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
# NOTE(review): strict=False silently ignores missing/unexpected keys.
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

# Independent copies that receive the fine-tuned style weights below.
generatorjojo = deepcopy(original_generator)
generatordisney = deepcopy(original_generator)
generatorjinx = deepcopy(original_generator)

# Resize to 1024x1024 and normalize each channel from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# e4e encoder checkpoint; presumably read from models/ by e4e_projection — confirm.
_gdrive_fetch("1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK", 'e4e_ffhq_encode.pt')
shutil.copyfile('e4e_ffhq_encode.pt', 'models/e4e_ffhq_encode.pt')

# Fine-tuned style checkpoints: each loads its "g" weights into one generator copy.
_gdrive_fetch("1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y", 'jojo.pt')
ckptjojo = torch.load('jojo.pt', map_location=device)
generatorjojo.load_state_dict(ckptjojo["g"], strict=False)

_gdrive_fetch("1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi", 'disney_preserve_color.pt')
ckptdisney = torch.load('disney_preserve_color.pt', map_location=device)
generatordisney.load_state_dict(ckptdisney["g"], strict=False)

_gdrive_fetch("1jElwHxaYPod5Itdy18izJk49K1nl4ney", 'arcane_jinx_preserve_color.pt')
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=device)
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
def inference(img, model):
    """Stylize the first face detected in *img* with the selected JoJoGAN generator.

    Args:
        img: RGB image as a NumPy array, as delivered by the Gradio
            ``Image(type="numpy")`` input (assumed uint8, HxWx3 — confirm).
        model: ``'JoJo'`` or ``'Disney'``; any other value selects the Jinx
            generator.

    Returns:
        Path of the JPEG file containing the stylized image.

    Raises:
        ValueError: if no face is detected in *img*.
    """
    face_helper.clean_all()
    # facexlib/OpenCV expect BGR channel order, whereas Gradio supplies RGB.
    face_helper.read_image(np.ascontiguousarray(img[:, :, ::-1]))
    face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=10)
    # Use the aligned crops kept in memory rather than a fixed on-disk file:
    # the previous "/home/user/app/_02.png" was the *third* face's crop, was
    # missing for single-face images, and could be a stale crop from an
    # earlier request.
    face_helper.align_warp_face()
    if not face_helper.cropped_faces:
        raise ValueError("No face was detected in the input image.")
    face_bgr = face_helper.cropped_faces[0]
    pilimg = Image.fromarray(np.ascontiguousarray(face_bgr[:, :, ::-1]))

    my_w = e4e_projection(pilimg, "test.pt", device).unsqueeze(0)
    generator = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
    }.get(model, generatorjinx)
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # Generator output is assumed to lie in [-1, 1] (the same range the
    # module's `transform` normalizes inputs to). Convert explicitly to uint8:
    # handing imageio a float array made it min/max-rescale every image.
    image = ((my_sample[0].clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)
    npimage = image.permute(1, 2, 0).cpu().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
# --- Gradio UI: page text, example inputs, and the interface wiring ---
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
examples = [['iu.jpeg', 'Jinx']]

# Inputs map positionally onto inference(img, model).
input_components = [
    gr.inputs.Image(type="numpy"),
    gr.inputs.Dropdown(
        choices=['JoJo', 'Disney', 'Jinx'],
        type="value",
        default='JoJo',
        label="Model",
    ),
]
# inference() returns a file path, so the output component takes a file.
output_component = gr.outputs.Image(type="file")

demo = gr.Interface(
    inference,
    input_components,
    output_component,
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
)
demo.launch()
|