Мясников Филипп Сергеевич commited on
Commit
52e0c74
1 Parent(s): 8c43b94
Files changed (1) hide show
  1. app.py +8 -131
app.py CHANGED
@@ -34,7 +34,7 @@ from util import *
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
- model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt")
38
  ckpt = torch.load(model_path_e, map_location='cpu')
39
  opts = ckpt['opts']
40
  opts['checkpoint_path'] = model_path_e
@@ -61,144 +61,21 @@ def projection(img, name, device='cuda'):
61
  return w_plus[0]
62
 
63
 
64
-
65
-
66
- device = 'cpu'
67
-
68
-
69
- latent_dim = 512
70
-
71
- model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
72
- original_generator = Generator(1024, latent_dim, 8, 2).to(device)
73
- ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
- original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
- mean_latent = original_generator.mean_latent(10000)
76
-
77
- generatorjojo = deepcopy(original_generator)
78
-
79
- generatordisney = deepcopy(original_generator)
80
-
81
- generatorjinx = deepcopy(original_generator)
82
-
83
- generatorcaitlyn = deepcopy(original_generator)
84
-
85
- generatoryasuho = deepcopy(original_generator)
86
-
87
- generatorarcanemulti = deepcopy(original_generator)
88
-
89
- generatorart = deepcopy(original_generator)
90
-
91
- generatorspider = deepcopy(original_generator)
92
-
93
- generatorsketch = deepcopy(original_generator)
94
-
95
-
96
- transform = transforms.Compose(
97
- [
98
- transforms.Resize((1024, 1024)),
99
- transforms.ToTensor(),
100
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
101
- ]
102
- )
103
-
104
-
105
-
106
-
107
- modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
108
-
109
-
110
- ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
111
- generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
112
-
113
-
114
- modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
115
-
116
- ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
117
- generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
-
119
-
120
- modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
-
122
- ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
- generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
-
125
-
126
- modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
-
128
- ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
- generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
-
131
-
132
- modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
-
134
- ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
- generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
-
137
-
138
- model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
-
140
- ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
- generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
-
143
-
144
- modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
-
146
- ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
- generatorart.load_state_dict(ckptart["g"], strict=False)
148
-
149
-
150
- modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
-
152
- ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
- generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
-
155
- modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
-
157
- ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
- generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
-
160
  def inference(img, model):
161
  img.save('out.jpg')
162
  aligned_face = align_face('out.jpg')
163
 
164
- my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
165
- if model == 'JoJo':
166
- with torch.no_grad():
167
- my_sample = generatorjojo(my_w, input_is_latent=True)
168
- elif model == 'Disney':
169
- with torch.no_grad():
170
- my_sample = generatordisney(my_w, input_is_latent=True)
171
- elif model == 'Jinx':
172
- with torch.no_grad():
173
- my_sample = generatorjinx(my_w, input_is_latent=True)
174
- elif model == 'Caitlyn':
175
- with torch.no_grad():
176
- my_sample = generatorcaitlyn(my_w, input_is_latent=True)
177
- elif model == 'Yasuho':
178
- with torch.no_grad():
179
- my_sample = generatoryasuho(my_w, input_is_latent=True)
180
- elif model == 'Arcane Multi':
181
- with torch.no_grad():
182
- my_sample = generatorarcanemulti(my_w, input_is_latent=True)
183
- elif model == 'Art':
184
- with torch.no_grad():
185
- my_sample = generatorart(my_w, input_is_latent=True)
186
- elif model == 'Spider-Verse':
187
- with torch.no_grad():
188
- my_sample = generatorspider(my_w, input_is_latent=True)
189
- else:
190
- with torch.no_grad():
191
- my_sample = generatorsketch(my_w, input_is_latent=True)
192
-
193
 
194
- npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
195
  imageio.imwrite('filename.jpeg', npimage)
196
  return 'filename.jpeg'
197
 
198
  title = "JoJoGAN"
199
  description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
200
 
201
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
202
-
203
- examples=[['mona.png','Jinx']]
204
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
 
 
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
+ model_path_e = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq5_cat.pt")
38
  ckpt = torch.load(model_path_e, map_location='cpu')
39
  opts = ckpt['opts']
40
  opts['checkpoint_path'] = model_path_e
 
61
  return w_plus[0]
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def inference(img, model):
65
  img.save('out.jpg')
66
  aligned_face = align_face('out.jpg')
67
 
68
+ my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ npimage = my_w.permute(1, 2, 0).detach().numpy()
71
  imageio.imwrite('filename.jpeg', npimage)
72
  return 'filename.jpeg'
73
 
74
  title = "JoJoGAN"
75
  description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
76
 
77
+ gr.Interface(inference,
78
+ [gr.inputs.Image(type="pil")],
79
+ gr.outputs.Image(type="file"),
80
+ title=title,
81
+ description=description).launch()