Ahsen Khaliq commited on
Commit
e00d153
1 Parent(s): 3df6d65

add arcane multi model

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -3,7 +3,6 @@ from PIL import Image
3
  import torch
4
  import gradio as gr
5
  os.system("pip install gradio==2.5.3")
6
-
7
  import torch
8
  torch.backends.cudnn.benchmark = True
9
  from torchvision import transforms, utils
@@ -54,6 +53,9 @@ generatorcaitlyn = deepcopy(original_generator)
54
 
55
  generatoryasuho = deepcopy(original_generator)
56
 
 
 
 
57
 
58
 
59
 
@@ -93,6 +95,11 @@ os.system("gdown https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_
93
  ckptyasuho = torch.load('jojo_yasuho_preserve_color.pt', map_location=lambda storage, loc: storage)
94
  generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
95
 
 
 
 
 
 
96
 
97
  def inference(img, model):
98
  aligned_face = align_face(img)
@@ -110,9 +117,12 @@ def inference(img, model):
110
  elif model == 'Caitlyn':
111
  with torch.no_grad():
112
  my_sample = generatorcaitlyn(my_w, input_is_latent=True)
113
- else:
114
  with torch.no_grad():
115
  my_sample = generatoryasuho(my_w, input_is_latent=True)
 
 
 
116
 
117
 
118
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
@@ -125,4 +135,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
125
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
126
 
127
  examples=[['iu.jpeg','Jinx']]
128
- gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
3
  import torch
4
  import gradio as gr
5
  os.system("pip install gradio==2.5.3")
 
6
  import torch
7
  torch.backends.cudnn.benchmark = True
8
  from torchvision import transforms, utils
53
 
54
  generatoryasuho = deepcopy(original_generator)
55
 
56
+ generatorarcanemulti = deepcopy(original_generator)
57
+
58
+
59
 
60
 
61
 
95
  ckptyasuho = torch.load('jojo_yasuho_preserve_color.pt', map_location=lambda storage, loc: storage)
96
  generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
97
 
98
+ os.system("gdown https://drive.google.com/uc?id=1enJgrC08NpWpx2XGBmLt1laimjpGCyfl")
99
+
100
+ ckptarcanemulti = torch.load('arcane_multi_preserve_color.pt', map_location=lambda storage, loc: storage)
101
+ generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
102
+
103
 
104
  def inference(img, model):
105
  aligned_face = align_face(img)
117
  elif model == 'Caitlyn':
118
  with torch.no_grad():
119
  my_sample = generatorcaitlyn(my_w, input_is_latent=True)
120
+ elif model == 'Yasuho':
121
  with torch.no_grad():
122
  my_sample = generatoryasuho(my_w, input_is_latent=True)
123
+ else:
124
+ with torch.no_grad():
125
+ my_sample = generatorarcanemulti(my_w, input_is_latent=True)
126
 
127
 
128
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
135
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
136
 
137
  examples=[['iu.jpeg','Jinx']]
138
+ gr.Interface(inference, [gr.inputs.Image(type="filepath"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()