hysts HF staff commited on
Commit
8e6dd19
1 Parent(s): 9124d49
Files changed (1) hide show
  1. app.py +22 -32
app.py CHANGED
@@ -7,7 +7,6 @@ import functools
7
  import os
8
  import pathlib
9
  import sys
10
- import tarfile
11
  from typing import Callable
12
 
13
  if os.environ.get('SYSTEM') == 'spaces':
@@ -29,6 +28,24 @@ from model.encoder.align_all_parallel import align_face
29
  from model.encoder.psp import pSp
30
  from util import load_image, visualize
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  TOKEN = os.environ['TOKEN']
33
 
34
  MODEL_REPO = 'hysts/DualStyleGAN'
@@ -49,17 +66,6 @@ def parse_args() -> argparse.Namespace:
49
  return parser.parse_args()
50
 
51
 
52
- def download_cartoon_images() -> None:
53
- image_dir = pathlib.Path('cartoon')
54
- if not image_dir.exists():
55
- path = huggingface_hub.hf_hub_download('hysts/DualStyleGAN-Cartoon',
56
- 'cartoon.tar.gz',
57
- repo_type='dataset',
58
- use_auth_token=TOKEN)
59
- with tarfile.open(path) as f:
60
- f.extractall()
61
-
62
-
63
  def load_encoder(device: torch.device) -> nn.Module:
64
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
65
  'models/encoder.pt',
@@ -188,13 +194,7 @@ def run(
188
  img_gen1 = postprocess(img_gen[1])
189
  img_gen2 = postprocess(img_gen2[0])
190
 
191
- try:
192
- style_image_dir = pathlib.Path(style_type)
193
- style_image = PIL.Image.open(style_image_dir / stylename)
194
- except Exception:
195
- style_image = None
196
-
197
- return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
198
 
199
 
200
  def main():
@@ -221,7 +221,6 @@ def main():
221
  for style_type in style_types
222
  }
223
 
224
- download_cartoon_images()
225
  dlib_landmark_model = create_dlib_landmark_model()
226
  encoder = load_encoder(device)
227
  transform = create_transform()
@@ -235,14 +234,6 @@ def main():
235
  device=device)
236
  func = functools.update_wrapper(func, run)
237
 
238
- repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
239
- title = 'williamyang1991/DualStyleGAN'
240
- description = f"""A demo for {repo_url}
241
-
242
- You can select style images for cartoon from the table below.
243
- """
244
- article = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
245
-
246
  image_paths = sorted(pathlib.Path('images').glob('*'))
247
  examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
248
 
@@ -260,7 +251,6 @@ def main():
260
  ],
261
  [
262
  gr.outputs.Image(type='pil', label='Aligned Face'),
263
- gr.outputs.Image(type='pil', label='Selected Style Image'),
264
  gr.outputs.Image(type='pil', label='Reconstructed'),
265
  gr.outputs.Image(type='pil', label='Result 1'),
266
  gr.outputs.Image(type='pil', label='Result 2'),
@@ -268,9 +258,9 @@ def main():
268
  ],
269
  examples=examples,
270
  theme=args.theme,
271
- title=title,
272
- description=description,
273
- article=article,
274
  allow_screenshot=args.allow_screenshot,
275
  allow_flagging=args.allow_flagging,
276
  live=args.live,
 
7
  import os
8
  import pathlib
9
  import sys
 
10
  from typing import Callable
11
 
12
  if os.environ.get('SYSTEM') == 'spaces':
 
28
  from model.encoder.psp import pSp
29
  from util import load_image, visualize
30
 
31
+ ORIGINAL_REPO_URL = 'https://github.com/williamyang1991/DualStyleGAN'
32
+ TITLE = 'williamyang1991/DualStyleGAN'
33
+ DESCRIPTION = f"""A demo for {ORIGINAL_REPO_URL}
34
+
35
+ You can select style images for cartoon from the table below.
36
+
37
+ The style image index should be in the following range:
38
+
39
+ - cartoon: 0-316
40
+ - caricature: 0-198
41
+ - anime: 0-173
42
+ - arcane: 0-99
43
+ - comic: 0-100
44
+ - pixar: 0-121
45
+ - slamdunk: 0-119
46
+ """
47
+ ARTICLE = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
48
+
49
  TOKEN = os.environ['TOKEN']
50
 
51
  MODEL_REPO = 'hysts/DualStyleGAN'
 
66
  return parser.parse_args()
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
69
  def load_encoder(device: torch.device) -> nn.Module:
70
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
71
  'models/encoder.pt',
 
194
  img_gen1 = postprocess(img_gen[1])
195
  img_gen2 = postprocess(img_gen2[0])
196
 
197
+ return image, img_rec, img_gen0, img_gen1, img_gen2
 
 
 
 
 
 
198
 
199
 
200
  def main():
 
221
  for style_type in style_types
222
  }
223
 
 
224
  dlib_landmark_model = create_dlib_landmark_model()
225
  encoder = load_encoder(device)
226
  transform = create_transform()
 
234
  device=device)
235
  func = functools.update_wrapper(func, run)
236
 
 
 
 
 
 
 
 
 
237
  image_paths = sorted(pathlib.Path('images').glob('*'))
238
  examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
239
 
 
251
  ],
252
  [
253
  gr.outputs.Image(type='pil', label='Aligned Face'),
 
254
  gr.outputs.Image(type='pil', label='Reconstructed'),
255
  gr.outputs.Image(type='pil', label='Result 1'),
256
  gr.outputs.Image(type='pil', label='Result 2'),
 
258
  ],
259
  examples=examples,
260
  theme=args.theme,
261
+ title=TITLE,
262
+ description=DESCRIPTION,
263
+ article=ARTICLE,
264
  allow_screenshot=args.allow_screenshot,
265
  allow_flagging=args.allow_flagging,
266
  live=args.live,