Spaces:
Runtime error
Runtime error
Ahsen Khaliq
commited on
Commit
·
ec8f0b0
1
Parent(s):
147ab10
Update app.py
Browse files
app.py
CHANGED
@@ -50,6 +50,9 @@ generatorjojo = deepcopy(original_generator)
|
|
50 |
|
51 |
generatordisney = deepcopy(original_generator)
|
52 |
|
|
|
|
|
|
|
53 |
|
54 |
transform = transforms.Compose(
|
55 |
[
|
@@ -72,19 +75,25 @@ os.system("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_ts
|
|
72 |
ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
|
73 |
generatordisney.load_state_dict(ckptdisney["g"], strict=False)
|
74 |
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
def inference(img, model):
|
77 |
aligned_face = align_face(img)
|
78 |
-
|
79 |
my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
|
80 |
if model == 'JoJo':
|
81 |
with torch.no_grad():
|
82 |
-
|
83 |
my_sample = generatorjojo(my_w, input_is_latent=True)
|
84 |
-
|
85 |
with torch.no_grad():
|
86 |
-
|
87 |
my_sample = generatordisney(my_w, input_is_latent=True)
|
|
|
|
|
|
|
|
|
88 |
|
89 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
90 |
imageio.imwrite('filename.jpeg', npimage)
|
@@ -95,5 +104,5 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
|
|
95 |
|
96 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
|
97 |
|
98 |
-
examples=[['iu.jpeg']]
|
99 |
-
gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256)),gr.inputs.Dropdown(choices=['JoJo', 'Disney'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
|
|
|
50 |
|
51 |
generatordisney = deepcopy(original_generator)
|
52 |
|
53 |
+
generatorjinx = deepcopy(original_generator)
|
54 |
+
|
55 |
+
|
56 |
|
57 |
transform = transforms.Compose(
|
58 |
[
|
|
|
75 |
ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
|
76 |
generatordisney.load_state_dict(ckptdisney["g"], strict=False)
|
77 |
|
78 |
+
os.system("gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney")
|
79 |
+
|
80 |
+
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
|
81 |
+
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
|
82 |
+
|
83 |
|
84 |
def inference(img, model):
|
85 |
aligned_face = align_face(img)
|
|
|
86 |
my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
|
87 |
if model == 'JoJo':
|
88 |
with torch.no_grad():
|
|
|
89 |
my_sample = generatorjojo(my_w, input_is_latent=True)
|
90 |
+
elif model == 'Disney':
|
91 |
with torch.no_grad():
|
|
|
92 |
my_sample = generatordisney(my_w, input_is_latent=True)
|
93 |
+
else:
|
94 |
+
with torch.no_grad():
|
95 |
+
my_sample = generatorjinx(my_w, input_is_latent=True)
|
96 |
+
|
97 |
|
98 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
99 |
imageio.imwrite('filename.jpeg', npimage)
|
|
|
104 |
|
105 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
|
106 |
|
107 |
+
examples=[['iu.jpeg','Jinx']]
|
108 |
+
gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256)),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
|