hysts (HF staff) committed
Commit: 9124d49
Parent(s): f7662a1

Add other style types

Files changed (1):
  1. app.py +47 -19
app.py CHANGED
@@ -129,15 +129,22 @@ def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
 @torch.inference_mode()
 def run(
     image,
-    style_id: int,
+    style_type: str,
+    style_id: float,
     dlib_landmark_model,
     encoder: nn.Module,
-    generator: nn.Module,
-    exstyles: dict[str, np.ndarray],
+    generator_dict: dict[str, nn.Module],
+    exstyle_dict: dict[str, dict[str, np.ndarray]],
     transform: Callable,
     device: torch.device,
-    style_image_dir: pathlib.Path,
-) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
+) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
+           PIL.Image, PIL.Image]:
+    generator = generator_dict[style_type]
+    exstyles = exstyle_dict[style_type]
+
+    style_id = int(style_id)
+    style_id = min(max(0, style_id), len(exstyles) - 1)
+
     stylename = list(exstyles.keys())[style_id]

     image = align_face(filepath=image.name, predictor=dlib_landmark_model)
@@ -181,7 +188,11 @@ def run(
     img_gen1 = postprocess(img_gen[1])
     img_gen2 = postprocess(img_gen2[0])

-    style_image = PIL.Image.open(style_image_dir / stylename)
+    try:
+        style_image_dir = pathlib.Path(style_type)
+        style_image = PIL.Image.open(style_image_dir / stylename)
+    except Exception:
+        style_image = None

     return image, style_image, img_rec, img_gen0, img_gen1, img_gen2

@@ -192,43 +203,60 @@ def main():
     args = parse_args()
     device = torch.device(args.device)

-    style_type = 'cartoon'
-    style_image_dir = pathlib.Path(style_type)
+    style_types = [
+        'cartoon',
+        'caricature',
+        'anime',
+        'arcane',
+        'comic',
+        'pixar',
+        'slamdunk',
+    ]
+    generator_dict = {
+        style_type: load_generator(style_type, device)
+        for style_type in style_types
+    }
+    exstyle_dict = {
+        style_type: load_exstylecode(style_type)
+        for style_type in style_types
+    }

     download_cartoon_images()
     dlib_landmark_model = create_dlib_landmark_model()
     encoder = load_encoder(device)
-    generator = load_generator(style_type, device)
-    exstyles = load_exstylecode(style_type)
     transform = create_transform()

     func = functools.partial(run,
                              dlib_landmark_model=dlib_landmark_model,
                              encoder=encoder,
-                             generator=generator,
-                             exstyles=exstyles,
+                             generator_dict=generator_dict,
+                             exstyle_dict=exstyle_dict,
                              transform=transform,
-                             device=device,
-                             style_image_dir=style_image_dir)
+                             device=device)
     func = functools.update_wrapper(func, run)

     repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
     title = 'williamyang1991/DualStyleGAN'
     description = f"""A demo for {repo_url}

-You can select style images from the table below.
+You can select style images for cartoon from the table below.
 """
     article = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'

     image_paths = sorted(pathlib.Path('images').glob('*'))
-    examples = [[path.as_posix(), 26] for path in image_paths]
+    examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]

     gr.Interface(
         func,
         [
-            gr.inputs.Image(type='file', label='Image'),
-            gr.inputs.Slider(
-                0, 316, step=1, default=26, label='Style Image Index'),
+            gr.inputs.Image(type='file', label='Input Image'),
+            gr.inputs.Radio(
+                style_types,
+                type='value',
+                default='cartoon',
+                label='Style Type',
+            ),
+            gr.inputs.Number(default=26, label='Style Image Index'),
         ],
         [
             gr.outputs.Image(type='pil', label='Aligned Face'),
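
Note on the updated run() signature: the style index now arrives as a float from gr.inputs.Number and is clamped to the valid range of the selected style's exemplar codes, while the generators and exstyle codes for every style type are preloaded into dictionaries keyed by style name. Below is a minimal standalone sketch of that lookup, assuming the same clamping as in the commit; pick_stylename is a hypothetical helper, and the dummy exstyle_dict (its filenames and entry counts are placeholders) only stands in for the data returned by load_exstylecode.

    import numpy as np

    # Placeholder exemplar-code dicts; the real ones come from load_exstylecode(style_type).
    exstyle_dict = {
        'cartoon': {f'cartoon_{i:03d}.jpg': np.zeros(512) for i in range(317)},
        'anime': {f'anime_{i:03d}.jpg': np.zeros(512) for i in range(10)},
    }

    def pick_stylename(style_type: str, style_id: float) -> str:
        # Mirror the clamping added to run(): cast the Number input to int
        # and keep it inside [0, len(exstyles) - 1].
        exstyles = exstyle_dict[style_type]
        style_id = int(style_id)
        style_id = min(max(0, style_id), len(exstyles) - 1)
        return list(exstyles.keys())[style_id]

    print(pick_stylename('cartoon', 26))     # in-range index is used as-is
    print(pick_stylename('anime', 1000.0))   # out-of-range input is clamped to the last exemplar

The clamp is needed because the free-form Number input, unlike the old 0-316 slider, does not constrain the value on its own.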