hysts HF staff commited on
Commit
3f9cb7e
1 Parent(s): 89d0aa8
Files changed (4) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +56 -71
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ⚡
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.17
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
  import pathlib
@@ -24,28 +23,12 @@ sys.path.insert(0, 'gan-control/src')
24
 
25
  from gan_control.inference.controller import Controller
26
 
27
- TITLE = 'amazon-research/gan-control'
28
- DESCRIPTION = '''This is an unofficial demo for https://github.com/amazon-research/gan-control.
29
 
30
- Expected execution time on Hugging Face Spaces: 7s (for one image)
31
  '''
32
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.gan-control" alt="visitor badge"/></center>'
33
 
34
- TOKEN = os.environ['TOKEN']
35
-
36
-
37
- def parse_args() -> argparse.Namespace:
38
- parser = argparse.ArgumentParser()
39
- parser.add_argument('--device', type=str, default='cpu')
40
- parser.add_argument('--theme', type=str)
41
- parser.add_argument('--live', action='store_true')
42
- parser.add_argument('--share', action='store_true')
43
- parser.add_argument('--port', type=int)
44
- parser.add_argument('--disable-queue',
45
- dest='enable_queue',
46
- action='store_false')
47
- parser.add_argument('--allow-flagging', type=str, default='never')
48
- return parser.parse_args()
49
 
50
 
51
  def download_models() -> None:
@@ -108,54 +91,56 @@ def run(
108
  return res0, res1, res2, res3
109
 
110
 
111
- def main():
112
- args = parse_args()
113
- device = torch.device(args.device)
114
-
115
- download_models()
116
-
117
- path = 'controller_age015id025exp02hai04ori02gam15/'
118
- controller = Controller(path, device)
119
-
120
- func = functools.partial(run, controller=controller, device=device)
121
- func = functools.update_wrapper(func, run)
122
-
123
- gr.Interface(
124
- func,
125
- [
126
- gr.inputs.Number(default=0, label='Seed'),
127
- gr.inputs.Slider(0, 1, step=0.1, default=0.7, label='Truncation'),
128
- gr.inputs.Slider(-90, 90, step=1, default=30, label='Yaw'),
129
- gr.inputs.Slider(-90, 90, step=1, default=0, label='Pitch'),
130
- gr.inputs.Slider(15, 75, step=1, default=75, label='Age'),
131
- gr.inputs.Slider(
132
- 0, 255, step=1, default=186, label='Hair Color (R)'),
133
- gr.inputs.Slider(
134
- 0, 255, step=1, default=158, label='Hair Color (G)'),
135
- gr.inputs.Slider(
136
- 0, 255, step=1, default=92, label='Hair Color (B)'),
137
- gr.inputs.Slider(1, 3, step=1, default=1, label='Number of Rows'),
138
- gr.inputs.Slider(
139
- 1, 5, step=1, default=5, label='Number of Columns'),
140
- ],
141
- [
142
- gr.outputs.Image(type='pil', label='Generated Image'),
143
- gr.outputs.Image(type='pil', label='Head Pose Controlled'),
144
- gr.outputs.Image(type='pil', label='Age Controlled'),
145
- gr.outputs.Image(type='pil', label='Hair Color Controlled'),
146
- ],
147
- title=TITLE,
148
- description=DESCRIPTION,
149
- article=ARTICLE,
150
- theme=args.theme,
151
- allow_flagging=args.allow_flagging,
152
- live=args.live,
153
- ).launch(
154
- enable_queue=args.enable_queue,
155
- server_port=args.port,
156
- share=args.share,
157
- )
158
-
159
-
160
- if __name__ == '__main__':
161
- main()
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import pathlib
 
23
 
24
  from gan_control.inference.controller import Controller
25
 
26
+ DESCRIPTION = '''GAN-Control
 
27
 
28
+ This is an unofficial demo for https://github.com/amazon-research/gan-control.
29
  '''
 
30
 
31
+ TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def download_models() -> None:
 
91
  return res0, res1, res2, res3
92
 
93
 
94
+ download_models()
95
+
96
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
97
+ path = 'controller_age015id025exp02hai04ori02gam15/'
98
+ controller = Controller(path, device)
99
+ func = functools.partial(run, controller=controller, device=device)
100
+
101
+ gr.Interface(
102
+ fn=func,
103
+ inputs=[
104
+ gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
105
+ gr.Slider(label='Truncation',
106
+ minimum=0,
107
+ maximum=1,
108
+ step=0.1,
109
+ value=0.7),
110
+ gr.Slider(label='Yaw', minimum=-90, maximum=90, step=1, value=30),
111
+ gr.Slider(label='Pitch', minimum=-90, maximum=90, step=1, value=0),
112
+ gr.Slider(label='Age', minimum=15, maximum=75, step=1, value=75),
113
+ gr.Slider(label='Hair Color (R)',
114
+ minimum=0,
115
+ maximum=255,
116
+ step=1,
117
+ value=186),
118
+ gr.Slider(label='Hair Color (G)',
119
+ minimum=0,
120
+ maximum=255,
121
+ step=1,
122
+ value=158),
123
+ gr.Slider(label='Hair Color (B)',
124
+ minimum=0,
125
+ maximum=255,
126
+ step=1,
127
+ value=92),
128
+ gr.Slider(label='Number of Rows',
129
+ minimum=1,
130
+ maximum=3,
131
+ step=1,
132
+ value=1),
133
+ gr.Slider(label='Number of Columns',
134
+ minimum=1,
135
+ maximum=5,
136
+ step=1,
137
+ value=5),
138
+ ],
139
+ outputs=[
140
+ gr.Image(label='Generated Image', type='pil'),
141
+ gr.Image(label='Head Pose Controlled', type='pil'),
142
+ gr.Image(label='Age Controlled', type='pil'),
143
+ gr.Image(label='Hair Color Controlled', type='pil'),
144
+ ],
145
+ description=DESCRIPTION,
146
+ ).queue().launch(show_api=False)