hysts HF staff commited on
Commit
910c8b5
β€’
1 Parent(s): 99916f8
Files changed (5) hide show
  1. .pre-commit-config.yaml +3 -12
  2. README.md +4 -1
  3. app.py +116 -149
  4. model.py +5 -8
  5. requirements.txt +5 -5
.pre-commit-config.yaml CHANGED
@@ -21,26 +21,17 @@ repos:
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.812
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
 
32
  - repo: https://github.com/google/yapf
33
  rev: v0.32.0
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
37
- - repo: https://github.com/kynan/nbstripout
38
- rev: 0.5.0
39
- hooks:
40
- - id: nbstripout
41
- args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
- - repo: https://github.com/nbQA-dev/nbQA
43
- rev: 1.3.1
44
- hooks:
45
- - id: nbqa-isort
46
- - id: nbqa-yapf
 
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
  - repo: https://github.com/google/yapf
34
  rev: v0.32.0
35
  hooks:
36
  - id: yapf
37
  args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: πŸƒ
4
  colorFrom: red
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.1.3
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
4
  colorFrom: red
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
14
+
15
+ https://arxiv.org/abs/2202.00273
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import json
7
 
8
  import gradio as gr
@@ -10,24 +9,7 @@ import numpy as np
10
 
11
  from model import Model
12
 
13
- TITLE = '# StyleGAN-XL'
14
- DESCRIPTION = '''This is an unofficial demo for [https://github.com/autonomousvision/stylegan_xl](https://github.com/autonomousvision/stylegan_xl).
15
-
16
- Expected execution time on Hugging Face Spaces: 16s
17
- '''
18
- FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan-xl" alt="visitor badge" />'
19
-
20
-
21
- def parse_args() -> argparse.Namespace:
22
- parser = argparse.ArgumentParser()
23
- parser.add_argument('--device', type=str, default='cpu')
24
- parser.add_argument('--theme', type=str)
25
- parser.add_argument('--share', action='store_true')
26
- parser.add_argument('--port', type=int)
27
- parser.add_argument('--disable-queue',
28
- dest='enable_queue',
29
- action='store_false')
30
- return parser.parse_args()
31
 
32
 
33
  def update_class_index(name: str) -> dict:
@@ -105,133 +87,118 @@ def update_class_name(model_name: str, index: int) -> dict:
105
  return gr.Textbox.update(visible=False)
106
 
107
 
108
- def main():
109
- args = parse_args()
110
- model = Model(args.device)
111
-
112
- with gr.Blocks(theme=args.theme, css='style.css') as demo:
113
- gr.Markdown(TITLE)
114
- gr.Markdown(DESCRIPTION)
115
-
116
- with gr.Tabs():
117
- with gr.TabItem('App'):
118
- with gr.Row():
119
- with gr.Column():
120
- with gr.Group():
121
- model_name = gr.Dropdown(
122
- model.MODEL_NAMES,
123
- value=model.MODEL_NAMES[3],
124
- label='Model')
125
- seed = gr.Slider(0,
126
- np.iinfo(np.uint32).max,
127
- step=1,
128
- value=0,
129
- label='Seed')
130
- psi = gr.Slider(0,
131
- 2,
132
- step=0.05,
133
- value=0.7,
134
- label='Truncation psi')
135
- class_index = gr.Slider(0,
136
- 999,
137
- step=1,
138
- value=83,
139
- label='Class Index')
140
- class_name = gr.Textbox(
141
- value=IMAGENET_NAMES[class_index.value],
142
- label='Class Label',
143
- interactive=False)
144
- tx = gr.Slider(-1,
145
- 1,
146
- step=0.05,
147
- value=0,
148
- label='Translate X')
149
- ty = gr.Slider(-1,
150
- 1,
151
- step=0.05,
152
- value=0,
153
- label='Translate Y')
154
- angle = gr.Slider(-180,
155
- 180,
156
- step=5,
157
- value=0,
158
- label='Angle')
159
- run_button = gr.Button('Run')
160
- with gr.Column():
161
- result = gr.Image(label='Result', elem_id='result')
162
-
163
- with gr.TabItem('Sample Images'):
164
- with gr.Row():
165
- model_name2 = gr.Dropdown([
166
- 'imagenet',
167
- 'cifar10',
168
- 'ffhq',
169
- 'pokemon',
170
- ],
171
- value='imagenet',
172
- label='Model')
173
- with gr.Row():
174
- text = get_sample_image_markdown(model_name2.value)
175
- sample_images = gr.Markdown(text)
176
-
177
- with gr.TabItem('Class Names'):
178
- with gr.Row():
179
- dataset_name = gr.Dropdown([
180
- 'imagenet',
181
- 'cifar10',
182
- ],
183
- value='imagenet',
184
- label='Dataset')
185
- with gr.Row():
186
- df = get_class_name_df('imagenet')
187
- class_names = gr.Dataframe(
188
- df,
189
- col_count=2,
190
- headers=['Class Index', 'Label'],
191
- interactive=False)
192
-
193
- gr.Markdown(FOOTER)
194
-
195
- model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
196
- model_name.change(fn=update_class_index,
197
- inputs=model_name,
198
- outputs=class_index)
199
- model_name.change(fn=update_class_name,
200
- inputs=[
201
- model_name,
202
- class_index,
203
- ],
204
- outputs=class_name)
205
- class_index.change(fn=update_class_name,
206
- inputs=[
207
- model_name,
208
- class_index,
209
- ],
210
- outputs=class_name)
211
- run_button.click(fn=model.set_model_and_generate_image,
212
- inputs=[
213
- model_name,
214
- seed,
215
- psi,
216
- class_index,
217
- tx,
218
- ty,
219
- angle,
220
- ],
221
- outputs=result)
222
- model_name2.change(fn=get_sample_image_markdown,
223
- inputs=model_name2,
224
- outputs=sample_images)
225
- dataset_name.change(fn=get_class_name_df,
226
- inputs=dataset_name,
227
- outputs=class_names)
228
-
229
- demo.launch(
230
- enable_queue=args.enable_queue,
231
- server_port=args.port,
232
- share=args.share,
233
- )
234
-
235
-
236
- if __name__ == '__main__':
237
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import json
6
 
7
  import gradio as gr
 
9
 
10
  from model import Model
11
 
12
+ DESCRIPTION = '# [StyleGAN-XL](https://github.com/autonomousvision/stylegan_xl)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def update_class_index(name: str) -> dict:
 
87
  return gr.Textbox.update(visible=False)
88
 
89
 
90
+ model = Model()
91
+
92
+ with gr.Blocks(css='style.css') as demo:
93
+ gr.Markdown(DESCRIPTION)
94
+
95
+ with gr.Tabs():
96
+ with gr.TabItem('App'):
97
+ with gr.Row():
98
+ with gr.Column():
99
+ with gr.Group():
100
+ model_name = gr.Dropdown(model.MODEL_NAMES,
101
+ value=model.MODEL_NAMES[3],
102
+ label='Model')
103
+ seed = gr.Slider(0,
104
+ np.iinfo(np.uint32).max,
105
+ step=1,
106
+ value=0,
107
+ label='Seed')
108
+ psi = gr.Slider(0,
109
+ 2,
110
+ step=0.05,
111
+ value=0.7,
112
+ label='Truncation psi')
113
+ class_index = gr.Slider(0,
114
+ 999,
115
+ step=1,
116
+ value=83,
117
+ label='Class Index')
118
+ class_name = gr.Textbox(
119
+ value=IMAGENET_NAMES[class_index.value],
120
+ label='Class Label',
121
+ interactive=False)
122
+ tx = gr.Slider(-1,
123
+ 1,
124
+ step=0.05,
125
+ value=0,
126
+ label='Translate X')
127
+ ty = gr.Slider(-1,
128
+ 1,
129
+ step=0.05,
130
+ value=0,
131
+ label='Translate Y')
132
+ angle = gr.Slider(-180,
133
+ 180,
134
+ step=5,
135
+ value=0,
136
+ label='Angle')
137
+ run_button = gr.Button('Run')
138
+ with gr.Column():
139
+ result = gr.Image(label='Result', elem_id='result')
140
+
141
+ with gr.TabItem('Sample Images'):
142
+ with gr.Row():
143
+ model_name2 = gr.Dropdown([
144
+ 'imagenet',
145
+ 'cifar10',
146
+ 'ffhq',
147
+ 'pokemon',
148
+ ],
149
+ value='imagenet',
150
+ label='Model')
151
+ with gr.Row():
152
+ text = get_sample_image_markdown(model_name2.value)
153
+ sample_images = gr.Markdown(text)
154
+
155
+ with gr.TabItem('Class Names'):
156
+ with gr.Row():
157
+ dataset_name = gr.Dropdown([
158
+ 'imagenet',
159
+ 'cifar10',
160
+ ],
161
+ value='imagenet',
162
+ label='Dataset')
163
+ with gr.Row():
164
+ df = get_class_name_df('imagenet')
165
+ class_names = gr.Dataframe(df,
166
+ col_count=2,
167
+ headers=['Class Index', 'Label'],
168
+ interactive=False)
169
+
170
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
171
+ model_name.change(fn=update_class_index,
172
+ inputs=model_name,
173
+ outputs=class_index)
174
+ model_name.change(fn=update_class_name,
175
+ inputs=[
176
+ model_name,
177
+ class_index,
178
+ ],
179
+ outputs=class_name)
180
+ class_index.change(fn=update_class_name,
181
+ inputs=[
182
+ model_name,
183
+ class_index,
184
+ ],
185
+ outputs=class_name)
186
+ run_button.click(fn=model.set_model_and_generate_image,
187
+ inputs=[
188
+ model_name,
189
+ seed,
190
+ psi,
191
+ class_index,
192
+ tx,
193
+ ty,
194
+ angle,
195
+ ],
196
+ outputs=result)
197
+ model_name2.change(fn=get_sample_image_markdown,
198
+ inputs=model_name2,
199
+ outputs=sample_images)
200
+ dataset_name.change(fn=get_class_name_df,
201
+ inputs=dataset_name,
202
+ outputs=class_names)
203
+
204
+ demo.queue(max_size=10).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -1,6 +1,5 @@
1
  from __future__ import annotations
2
 
3
- import os
4
  import pathlib
5
  import pickle
6
  import sys
@@ -14,8 +13,6 @@ current_dir = pathlib.Path(__file__).parent
14
  submodule_dir = current_dir / 'stylegan_xl'
15
  sys.path.insert(0, submodule_dir.as_posix())
16
 
17
- HF_TOKEN = os.environ['HF_TOKEN']
18
-
19
 
20
  class Model:
21
 
@@ -29,16 +26,16 @@ class Model:
29
  'pokemon256',
30
  ]
31
 
32
- def __init__(self, device: str | torch.device):
33
- self.device = torch.device(device)
 
34
  self._download_all_models()
35
  self.model_name = self.MODEL_NAMES[3]
36
  self.model = self._load_model(self.model_name)
37
 
38
  def _load_model(self, model_name: str) -> nn.Module:
39
- path = hf_hub_download('hysts/StyleGAN-XL',
40
- f'models/{model_name}.pkl',
41
- use_auth_token=HF_TOKEN)
42
  with open(path, 'rb') as f:
43
  model = pickle.load(f)['G_ema']
44
  model.eval()
 
1
  from __future__ import annotations
2
 
 
3
  import pathlib
4
  import pickle
5
  import sys
 
13
  submodule_dir = current_dir / 'stylegan_xl'
14
  sys.path.insert(0, submodule_dir.as_posix())
15
 
 
 
16
 
17
  class Model:
18
 
 
26
  'pokemon256',
27
  ]
28
 
29
+ def __init__(self):
30
+ self.device = torch.device(
31
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
32
  self._download_all_models()
33
  self.model_name = self.MODEL_NAMES[3]
34
  self.model = self._load_model(self.model_name)
35
 
36
  def _load_model(self, model_name: str) -> nn.Module:
37
+ path = hf_hub_download('public-data/StyleGAN-XL',
38
+ f'models/{model_name}.pkl')
 
39
  with open(path, 'rb') as f:
40
  model = pickle.load(f)['G_ema']
41
  model.eval()
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  ftfy==6.1.1
2
- numpy==1.22.3
3
- Pillow==9.0.1
4
- scipy==1.8.0
5
- torch==1.11.0
6
- torchvision==0.12.0
 
1
  ftfy==6.1.1
2
+ numpy==1.23.5
3
+ Pillow==10.0.0
4
+ scipy==1.10.1
5
+ torch==2.0.1
6
+ torchvision==0.15.2