Ahsen Khaliq commited on
Commit
11137cc
1 Parent(s): 541a5f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -102,6 +102,8 @@ generatorart = deepcopy(original_generator)
102
 
103
  generatorspider = deepcopy(original_generator)
104
 
 
 
105
 
106
  transform = transforms.Compose(
107
  [
@@ -162,6 +164,10 @@ modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filen
162
  ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
163
  generatorspider.load_state_dict(ckptspider["g"], strict=False)
164
 
 
 
 
 
165
 
166
  def inference(img, model):
167
  img.save('out.jpg')
@@ -189,9 +195,12 @@ def inference(img, model):
189
  elif model == 'Art':
190
  with torch.no_grad():
191
  my_sample = generatorart(my_w, input_is_latent=True)
192
- else:
193
  with torch.no_grad():
194
  my_sample = generatorspider(my_w, input_is_latent=True)
 
 
 
195
 
196
 
197
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
@@ -204,4 +213,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
204
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
205
 
206
  examples=[['mona.png','Jinx']]
207
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)
102
 
103
  generatorspider = deepcopy(original_generator)
104
 
105
+ generatorsketch = deepcopy(original_generator)
106
+
107
 
108
  transform = transforms.Compose(
109
  [
164
  ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
165
  generatorspider.load_state_dict(ckptspider["g"], strict=False)
166
 
167
+ modelSketch = hf_hub_download(repo_id="akhaliq/akhaliq/jojogan-sketch", filename="sketch_multi.pt")
168
+
169
+ ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
170
+ generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
171
 
172
  def inference(img, model):
173
  img.save('out.jpg')
195
  elif model == 'Art':
196
  with torch.no_grad():
197
  my_sample = generatorart(my_w, input_is_latent=True)
198
+ elif model == 'Spider-Verse':
199
  with torch.no_grad():
200
  my_sample = generatorspider(my_w, input_is_latent=True)
201
+ else:
202
+ with torch.no_grad():
203
+ my_sample = generatorsketch(my_w, input_is_latent=True)
204
 
205
 
206
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
213
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
214
 
215
  examples=[['mona.png','Jinx']]
216
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)