Spaces:
Runtime error
Runtime error
Ahsen Khaliq
commited on
Commit
•
49ce528
1
Parent(s):
eb212a1
add caitlyn model
Browse files
app.py
CHANGED
@@ -4,12 +4,6 @@ import torch
|
|
4 |
import gradio as gr
|
5 |
os.system("pip install gradio==2.5.3")
|
6 |
|
7 |
-
#os.system("pip install facexlib")
|
8 |
-
|
9 |
-
#from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
10 |
-
#os.system("pip install autocrop")
|
11 |
-
#os.system("pip install dlib")
|
12 |
-
#from autocrop import Cropper
|
13 |
import torch
|
14 |
torch.backends.cudnn.benchmark = True
|
15 |
from torchvision import transforms, utils
|
@@ -38,15 +32,6 @@ os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
|
|
38 |
os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
|
39 |
os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
|
40 |
|
41 |
-
#cropper = Cropper(face_percent=80)
|
42 |
-
|
43 |
-
#face_helper = FaceRestoreHelper(
|
44 |
-
#upscale_factor=0,
|
45 |
-
#face_size=512,
|
46 |
-
#crop_ratio=(1, 1),
|
47 |
-
#det_model='retinaface_resnet50',
|
48 |
-
#save_ext='png',
|
49 |
-
#device='cpu')
|
50 |
|
51 |
device = 'cpu'
|
52 |
|
@@ -54,19 +39,19 @@ os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_bt
|
|
54 |
|
55 |
latent_dim = 512
|
56 |
|
57 |
-
# Load original generator
|
58 |
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
|
59 |
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
|
60 |
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
|
61 |
mean_latent = original_generator.mean_latent(10000)
|
62 |
|
63 |
-
# to be finetuned generator
|
64 |
generatorjojo = deepcopy(original_generator)
|
65 |
|
66 |
generatordisney = deepcopy(original_generator)
|
67 |
|
68 |
generatorjinx = deepcopy(original_generator)
|
69 |
|
|
|
|
|
70 |
|
71 |
|
72 |
transform = transforms.Compose(
|
@@ -95,22 +80,15 @@ os.system("gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ne
|
|
95 |
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
|
96 |
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
def inference(img, model):
|
100 |
-
#face_helper.clean_all()
|
101 |
aligned_face = align_face(img)
|
102 |
-
|
103 |
-
|
104 |
-
#if cropped_array.any():
|
105 |
-
#aligned_face = Image.fromarray(cropped_array)
|
106 |
-
#else:
|
107 |
-
#aligned_face = Image.fromarray(img[:,:,::-1])
|
108 |
-
|
109 |
-
#face_helper.read_image(img)
|
110 |
-
#face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=10)
|
111 |
-
#face_helper.align_warp_face(save_cropped_path="/home/user/app/")
|
112 |
-
#pilimg = Image.open("/home/user/app/_02.png")
|
113 |
-
|
114 |
my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
|
115 |
if model == 'JoJo':
|
116 |
with torch.no_grad():
|
@@ -118,9 +96,12 @@ def inference(img, model):
|
|
118 |
elif model == 'Disney':
|
119 |
with torch.no_grad():
|
120 |
my_sample = generatordisney(my_w, input_is_latent=True)
|
121 |
-
|
122 |
with torch.no_grad():
|
123 |
my_sample = generatorjinx(my_w, input_is_latent=True)
|
|
|
|
|
|
|
124 |
|
125 |
|
126 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
@@ -133,4 +114,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
|
|
133 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
|
134 |
|
135 |
examples=[['iu.jpeg','Jinx']]
|
136 |
-
gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
|
4 |
import gradio as gr
|
5 |
os.system("pip install gradio==2.5.3")
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import torch
|
8 |
torch.backends.cudnn.benchmark = True
|
9 |
from torchvision import transforms, utils
|
32 |
os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
|
33 |
os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
device = 'cpu'
|
37 |
|
39 |
|
40 |
latent_dim = 512
|
41 |
|
|
|
42 |
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
|
43 |
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
|
44 |
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
|
45 |
mean_latent = original_generator.mean_latent(10000)
|
46 |
|
|
|
47 |
generatorjojo = deepcopy(original_generator)
|
48 |
|
49 |
generatordisney = deepcopy(original_generator)
|
50 |
|
51 |
generatorjinx = deepcopy(original_generator)
|
52 |
|
53 |
+
generatorcaitlyn = deepcopy(original_generator)
|
54 |
+
|
55 |
|
56 |
|
57 |
transform = transforms.Compose(
|
80 |
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
|
81 |
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
|
82 |
|
83 |
+
os.system("gdown https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH")
|
84 |
+
|
85 |
+
ckptcaitlyn = torch.load('arcane_caitlyn_preserve_color.pt', map_location=lambda storage, loc: storage)
|
86 |
+
generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
|
87 |
+
|
88 |
|
89 |
def inference(img, model):
|
|
|
90 |
aligned_face = align_face(img)
|
91 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
|
93 |
if model == 'JoJo':
|
94 |
with torch.no_grad():
|
96 |
elif model == 'Disney':
|
97 |
with torch.no_grad():
|
98 |
my_sample = generatordisney(my_w, input_is_latent=True)
|
99 |
+
elif model == 'Jinx':
|
100 |
with torch.no_grad():
|
101 |
my_sample = generatorjinx(my_w, input_is_latent=True)
|
102 |
+
else:
|
103 |
+
with torch.no_grad():
|
104 |
+
my_sample = generatorcaitlyn(my_w, input_is_latent=True)
|
105 |
|
106 |
|
107 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
114 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
|
115 |
|
116 |
examples=[['iu.jpeg','Jinx']]
|
117 |
+
gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
|