Ahsen Khaliq commited on
Commit
ec8f0b0
1 Parent(s): 147ab10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -50,6 +50,9 @@ generatorjojo = deepcopy(original_generator)
50
 
51
  generatordisney = deepcopy(original_generator)
52
 
 
 
 
53
 
54
  transform = transforms.Compose(
55
  [
@@ -72,19 +75,25 @@ os.system("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_ts
72
  ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
73
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
74
 
 
 
 
 
 
75
 
76
  def inference(img, model):
77
  aligned_face = align_face(img)
78
-
79
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
80
  if model == 'JoJo':
81
  with torch.no_grad():
82
-
83
  my_sample = generatorjojo(my_w, input_is_latent=True)
84
- else:
85
  with torch.no_grad():
86
-
87
  my_sample = generatordisney(my_w, input_is_latent=True)
 
 
 
 
88
 
89
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
90
  imageio.imwrite('filename.jpeg', npimage)
@@ -95,5 +104,5 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
95
 
96
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
97
 
98
- examples=[['iu.jpeg']]
99
- gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256)),gr.inputs.Dropdown(choices=['JoJo', 'Disney'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
50
 
51
  generatordisney = deepcopy(original_generator)
52
 
53
+ generatorjinx = deepcopy(original_generator)
54
+
55
+
56
 
57
  transform = transforms.Compose(
58
  [
75
  ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
76
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
77
 
78
+ os.system("gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney")
79
+
80
+ ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
81
+ generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
82
+
83
 
84
  def inference(img, model):
85
  aligned_face = align_face(img)
 
86
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
87
  if model == 'JoJo':
88
  with torch.no_grad():
 
89
  my_sample = generatorjojo(my_w, input_is_latent=True)
90
+ elif model == 'Disney':
91
  with torch.no_grad():
 
92
  my_sample = generatordisney(my_w, input_is_latent=True)
93
+ else:
94
+ with torch.no_grad():
95
+ my_sample = generatorjinx(my_w, input_is_latent=True)
96
+
97
 
98
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
99
  imageio.imwrite('filename.jpeg', npimage)
104
 
105
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
106
 
107
+ examples=[['iu.jpeg','Jinx']]
108
+ gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256)),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()