algomuffin commited on
Commit
6c1e309
1 Parent(s): 80774a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -96,6 +96,7 @@ generatorspider = deepcopy(original_generator)
96
 
97
  generatorsketch = deepcopy(original_generator)
98
 
 
99
 
100
  transform = transforms.Compose(
101
  [
@@ -161,6 +162,15 @@ modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch
161
  ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
162
  generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
163
 
 
 
 
 
 
 
 
 
 
164
  def inference(img, model):
165
  img.save('out.jpg')
166
  aligned_face = align_face('out.jpg')
@@ -190,6 +200,9 @@ def inference(img, model):
190
  elif model == 'Spider-Verse':
191
  with torch.no_grad():
192
  my_sample = generatorspider(my_w, input_is_latent=True)
 
 
 
193
  else:
194
  with torch.no_grad():
195
  my_sample = generatorsketch(my_w, input_is_latent=True)
@@ -205,4 +218,4 @@ description = "Gradio Demo for JoJoGAN: This is a fork made by algomuffin in ord
205
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
206
 
207
  examples=[['baby-face.jpg','Jinx']]
208
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)
 
96
 
97
  generatorsketch = deepcopy(original_generator)
98
 
99
+ generatorMy = deepcopy(original_generator)
100
 
101
  transform = transforms.Compose(
102
  [
 
162
  ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
163
  generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
164
 
165
+
166
+ modelMy = hf_hub_download(repo_id="algomuffin/my_model", filename="my.pt")
167
+
168
+ ckptMy = torch.load(modelMy, map_location=lambda storage, loc: storage)
169
+ generatorMy.load_state_dict(ckptMy["g"], strict=False)
170
+
171
+
172
+
173
+
174
  def inference(img, model):
175
  img.save('out.jpg')
176
  aligned_face = align_face('out.jpg')
 
200
  elif model == 'Spider-Verse':
201
  with torch.no_grad():
202
  my_sample = generatorspider(my_w, input_is_latent=True)
203
+ elif model == 'My-model':
204
+ with torch.no_grad():
205
+ my_sample = generatorMy(my_w, input_is_latent=True)
206
  else:
207
  with torch.no_grad():
208
  my_sample = generatorsketch(my_w, input_is_latent=True)
 
218
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
219
 
220
  examples=[['baby-face.jpg','Jinx']]
221
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','My-model','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)