Ahsen Khaliq commited on
Commit
91959e5
1 Parent(s): 46afa05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -66,19 +66,32 @@ plt.rcParams['figure.dpi'] = 150
66
 
67
  os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
68
 
69
- ckpt = torch.load('jojo.pt', map_location=lambda storage, loc: storage)
70
- generator.load_state_dict(ckpt["g"], strict=False)
71
 
72
- def inference(img):
 
 
 
 
 
 
73
  aligned_face = align_face(img)
74
 
75
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- with torch.no_grad():
78
- generator.eval()
79
-
80
- original_my_sample = original_generator(my_w, input_is_latent=True)
81
- my_sample = generator(my_w, input_is_latent=True)
82
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
83
  imageio.imwrite('filename.jpeg', npimage)
84
  return 'filename.jpeg'
@@ -89,4 +102,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
89
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
90
 
91
  examples=[['iu.jpeg']]
92
- gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256))], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()
66
 
67
  os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
68
 
69
+ ckptjojo = torch.load('jojo.pt', map_location=lambda storage, loc: storage)
70
+ generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
71
 
72
+ os.system("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi")
73
+
74
+ ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
75
+ generatordisney.load_state_dict(ckptdisney["g"], strict=False)
76
+
77
+
78
+ def inference(img, model):
79
  aligned_face = align_face(img)
80
 
81
  my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
82
+ if model == 'JoJo':
83
+ with torch.no_grad():
84
+ generator.eval()
85
+
86
+ #original_my_sample = original_generator(my_w, input_is_latent=True)
87
+ my_sample = generatorjojo(my_w, input_is_latent=True)
88
+ else:
89
+ with torch.no_grad():
90
+ generator.eval()
91
+
92
+ #original_my_sample = original_generator(my_w, input_is_latent=True)
93
+ my_sample = generatordisney(my_w, input_is_latent=True)
94
 
 
 
 
 
 
95
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
96
  imageio.imwrite('filename.jpeg', npimage)
97
  return 'filename.jpeg'
102
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
103
 
104
  examples=[['iu.jpeg']]
105
+ gr.Interface(inference, [gr.inputs.Image(type="filepath",shape=(256,256)),gradio.inputs.Dropdown(choices=['JoJo', 'Disney'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False,examples=examples).launch()