Spaces:
Runtime error
Runtime error
Ahsen Khaliq
committed on
Commit
•
ddaf006
1
Parent(s):
79681a1
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
-
os.system("pip install dlib")
|
6 |
os.system("git clone https://github.com/mchong6/JoJoGAN.git")
|
7 |
os.chdir("JoJoGAN")
|
8 |
|
@@ -29,9 +29,9 @@ os.makedirs('style_images', exist_ok=True)
|
|
29 |
os.makedirs('style_images_aligned', exist_ok=True)
|
30 |
os.makedirs('models', exist_ok=True)
|
31 |
|
32 |
-
os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
|
33 |
-
os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
|
34 |
-
os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
|
35 |
|
36 |
|
37 |
device = 'cpu'
|
@@ -44,7 +44,7 @@ latent_dim = 512
|
|
44 |
|
45 |
# Load original generator
|
46 |
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
|
47 |
-
ckpt = torch.load(
|
48 |
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
|
49 |
mean_latent = original_generator.mean_latent(10000)
|
50 |
|
@@ -60,21 +60,12 @@ transform = transforms.Compose(
|
|
60 |
)
|
61 |
|
62 |
plt.rcParams['figure.dpi'] = 150
|
63 |
-
os.system("gdown https://drive.google.com/uc?id=1-7UlCppmiG4DKbhYDNbIZTc6mHy9JMWJ")
|
64 |
-
os.system("mv e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
|
65 |
|
66 |
|
67 |
os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
|
68 |
-
os.system("mv e4e_ffhq_encode.pt models/jojo.pt")
|
69 |
|
70 |
-
def inference(img):
|
71 |
-
|
72 |
-
|
73 |
-
name = strip_path_extension(filepath)+'.pt'
|
74 |
-
|
75 |
-
aligned_face = align_face(filepath)
|
76 |
-
|
77 |
-
my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
|
78 |
|
79 |
|
80 |
plt.rcParams['figure.dpi'] = 150
|
@@ -83,7 +74,7 @@ def inference(img):
|
|
83 |
preserve_color = False
|
84 |
|
85 |
|
86 |
-
ckpt = torch.load(
|
87 |
generator.load_state_dict(ckpt["g"], strict=False)
|
88 |
|
89 |
with torch.no_grad():
|
@@ -98,4 +89,4 @@ title = "AnimeGANv2"
|
|
98 |
description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
|
99 |
article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
|
100 |
|
101 |
-
gr.Interface(inference, [gr.inputs.Image(type="
|
|
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
+
#os.system("pip install dlib")
|
6 |
os.system("git clone https://github.com/mchong6/JoJoGAN.git")
|
7 |
os.chdir("JoJoGAN")
|
8 |
|
|
|
29 |
os.makedirs('style_images_aligned', exist_ok=True)
|
30 |
os.makedirs('models', exist_ok=True)
|
31 |
|
32 |
+
#os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
|
33 |
+
#os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
|
34 |
+
#os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
|
35 |
|
36 |
|
37 |
device = 'cpu'
|
|
|
44 |
|
45 |
# Load original generator
|
46 |
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
|
47 |
+
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
|
48 |
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
|
49 |
mean_latent = original_generator.mean_latent(10000)
|
50 |
|
|
|
60 |
)
|
61 |
|
62 |
plt.rcParams['figure.dpi'] = 150
|
|
|
|
|
63 |
|
64 |
|
65 |
os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
|
|
|
66 |
|
67 |
+
def inference(img):
|
68 |
+
my_w = e4e_projection(img, "test.pt", device).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
|
71 |
plt.rcParams['figure.dpi'] = 150
|
|
|
74 |
preserve_color = False
|
75 |
|
76 |
|
77 |
+
ckpt = torch.load('jojo.pt', map_location=lambda storage, loc: storage)
|
78 |
generator.load_state_dict(ckpt["g"], strict=False)
|
79 |
|
80 |
with torch.no_grad():
|
|
|
89 |
description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
|
90 |
article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
|
91 |
|
92 |
+
gr.Interface(inference, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="numpy"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False).launch()
|