hysts HF staff commited on
Commit
83321d4
1 Parent(s): f92ba48
Files changed (6) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +33 -60
  5. model.py +0 -1
  6. requirements.txt +2 -2
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📚
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
 
@@ -14,11 +13,8 @@ from huggingface_hub import hf_hub_download
14
 
15
  from model import Model
16
 
17
- TITLE = 'bes-dev/MobileStyleGAN.pytorch'
18
- DESCRIPTION = '''This is an unofficial demo for https://github.com/bes-dev/MobileStyleGAN.pytorch.
19
-
20
- Expected execution time on Hugging Face Spaces: 1s
21
- '''
22
  SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/MobileStyleGAN/resolve/main/samples'
23
  ARTICLE = f'''## Generated images
24
  ### FFHQ
@@ -26,25 +22,9 @@ ARTICLE = f'''## Generated images
26
  - seed: 0-99
27
  - truncation: 1.0
28
  ![FFHQ]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
29
-
30
- <center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.mobilestylegan" alt="visitor badge"/></center>
31
  '''
32
 
33
- TOKEN = os.environ['TOKEN']
34
-
35
-
36
- def parse_args() -> argparse.Namespace:
37
- parser = argparse.ArgumentParser()
38
- parser.add_argument('--device', type=str, default='cpu')
39
- parser.add_argument('--theme', type=str)
40
- parser.add_argument('--live', action='store_true')
41
- parser.add_argument('--share', action='store_true')
42
- parser.add_argument('--port', type=int)
43
- parser.add_argument('--disable-queue',
44
- dest='enable_queue',
45
- action='store_false')
46
- parser.add_argument('--allow-flagging', type=str, default='never')
47
- return parser.parse_args()
48
 
49
 
50
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
@@ -67,7 +47,7 @@ def generate_image(seed: int, truncation_psi: float, generator: str,
67
  def load_model(device: torch.device) -> nn.Module:
68
  path = hf_hub_download('hysts/MobileStyleGAN',
69
  'models/mobilestylegan_ffhq_v2.pth',
70
- use_auth_token=TOKEN)
71
  ckpt = torch.load(path)
72
  model = Model()
73
  model.load_state_dict(ckpt['state_dict'], strict=False)
@@ -79,39 +59,32 @@ def load_model(device: torch.device) -> nn.Module:
79
  return model
80
 
81
 
82
- def main():
83
- args = parse_args()
84
- device = torch.device(args.device)
85
-
86
- model = load_model(device)
87
-
88
- func = functools.partial(generate_image, model=model, device=device)
89
- func = functools.update_wrapper(func, generate_image)
90
-
91
- gr.Interface(
92
- func,
93
- [
94
- gr.inputs.Number(default=0, label='Seed'),
95
- gr.inputs.Slider(
96
- 0, 2, step=0.05, default=1.0, label='Truncation psi'),
97
- gr.inputs.Radio(['student', 'teacher'],
98
- type='value',
99
- default='student',
100
- label='Generator'),
101
- ],
102
- gr.outputs.Image(type='numpy', label='Output'),
103
- title=TITLE,
104
- description=DESCRIPTION,
105
- article=ARTICLE,
106
- theme=args.theme,
107
- allow_flagging=args.allow_flagging,
108
- live=args.live,
109
- ).launch(
110
- enable_queue=args.enable_queue,
111
- server_port=args.port,
112
- share=args.share,
113
- )
114
-
115
-
116
- if __name__ == '__main__':
117
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
 
 
13
 
14
  from model import Model
15
 
16
+ TITLE = 'MobileStyleGAN'
17
+ DESCRIPTION = 'This is an unofficial demo for https://github.com/bes-dev/MobileStyleGAN.pytorch.'
 
 
 
18
  SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/MobileStyleGAN/resolve/main/samples'
19
  ARTICLE = f'''## Generated images
20
  ### FFHQ
 
22
  - seed: 0-99
23
  - truncation: 1.0
24
  ![FFHQ]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
 
 
25
  '''
26
 
27
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
 
47
  def load_model(device: torch.device) -> nn.Module:
48
  path = hf_hub_download('hysts/MobileStyleGAN',
49
  'models/mobilestylegan_ffhq_v2.pth',
50
+ use_auth_token=HF_TOKEN)
51
  ckpt = torch.load(path)
52
  model = Model()
53
  model.load_state_dict(ckpt['state_dict'], strict=False)
 
59
  return model
60
 
61
 
62
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
63
+ model = load_model(device)
64
+
65
+ func = functools.partial(generate_image, model=model, device=device)
66
+
67
+ gr.Interface(
68
+ fn=func,
69
+ inputs=[
70
+ gr.Slider(label='Seed',
71
+ minimum=0,
72
+ maximum=100000,
73
+ step=1,
74
+ value=0,
75
+ randomize=True),
76
+ gr.Slider(label='Truncation psi',
77
+ minimum=0,
78
+ maximum=2,
79
+ step=0.05,
80
+ value=1.0),
81
+ gr.Radio(label='Generator',
82
+ choices=['student', 'teacher'],
83
+ type='value',
84
+ value='student'),
85
+ ],
86
+ outputs=gr.Image(label='Output', type='numpy'),
87
+ title=TITLE,
88
+ description=DESCRIPTION,
89
+ article=ARTICLE,
90
+ ).queue().launch(show_api=False)
 
 
 
 
 
 
 
model.py CHANGED
@@ -11,7 +11,6 @@ from core.models.synthesis_network import SynthesisNetwork
11
 
12
 
13
  class Model(nn.Module):
14
-
15
  def __init__(self):
16
  super().__init__()
17
  # teacher model
 
11
 
12
 
13
  class Model(nn.Module):
 
14
  def __init__(self):
15
  super().__init__()
16
  # teacher model
requirements.txt CHANGED
@@ -1,8 +1,8 @@
 
1
  numpy==1.22.3
2
  Pillow==9.0.1
3
- PyWavelets==1.2.0
4
  piq==0.6.0
 
5
  scipy==1.8.0
6
  torch==1.11.0
7
  torchvision==0.12.0
8
- git+https://github.com/fbcotter/pytorch_wavelets.git
 
1
+ git+https://github.com/fbcotter/pytorch_wavelets.git
2
  numpy==1.22.3
3
  Pillow==9.0.1
 
4
  piq==0.6.0
5
+ PyWavelets==1.2.0
6
  scipy==1.8.0
7
  torch==1.11.0
8
  torchvision==0.12.0