hysts (HF staff) committed
Commit: afe6a1b
Parent: 3078ca0
Files changed (5):
  1. .pre-commit-config.yaml +46 -0
  2. .style.yapf +5 -0
  3. app.py +67 -110
  4. model.py +86 -0
  5. style.css +11 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
+exclude: ^projected_gan
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.2.0
+  hooks:
+  - id: check-executables-have-shebangs
+  - id: check-json
+  - id: check-merge-conflict
+  - id: check-shebang-scripts-are-executable
+  - id: check-toml
+  - id: check-yaml
+  - id: double-quote-string-fixer
+  - id: end-of-file-fixer
+  - id: mixed-line-ending
+    args: ['--fix=lf']
+  - id: requirements-txt-fixer
+  - id: trailing-whitespace
+- repo: https://github.com/myint/docformatter
+  rev: v1.4
+  hooks:
+  - id: docformatter
+    args: ['--in-place']
+- repo: https://github.com/pycqa/isort
+  rev: 5.10.1
+  hooks:
+  - id: isort
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v0.812
+  hooks:
+  - id: mypy
+    args: ['--ignore-missing-imports']
+- repo: https://github.com/google/yapf
+  rev: v0.32.0
+  hooks:
+  - id: yapf
+    args: ['--parallel', '--in-place']
+- repo: https://github.com/kynan/nbstripout
+  rev: 0.5.0
+  hooks:
+  - id: nbstripout
+    args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
+- repo: https://github.com/nbQA-dev/nbQA
+  rev: 1.3.1
+  hooks:
+  - id: nbqa-isort
+  - id: nbqa-yapf
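These hooks run on every `git commit` once installed. As a minimal sketch of installing and running them locally (assuming the `pre-commit` package is installed and on PATH; the commands below are standard pre-commit CLI usage, not part of this commit):

```python
import subprocess

# Install the git hooks declared in .pre-commit-config.yaml,
# then run all of them once over the whole repository.
subprocess.run(['pre-commit', 'install'], check=True)
subprocess.run(['pre-commit', 'run', '--all-files'], check=True)
```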
.style.yapf ADDED
@@ -0,0 +1,5 @@
+[style]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = false
+spaces_before_comment = 2
+split_before_logical_operator = true
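yapf picks up this `[style]` section when formatting the repository. A small sketch of applying it through yapf's Python API, assuming yapf v0.32.0 as pinned in the pre-commit config (the sample source string is illustrative only):

```python
from yapf.yapflib.yapf_api import FormatCode

source = "x = {  'a':37,'b':42}\n"
# style_config points at the .style.yapf shown above.
formatted, changed = FormatCode(source, style_config='.style.yapf')
print(formatted)  # expected: x = {'a': 37, 'b': 42}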
app.py CHANGED
@@ -3,144 +3,101 @@
 from __future__ import annotations

 import argparse
-import functools
-import os
-import pickle
-import sys

 import gradio as gr
 import numpy as np
-import torch
-import torch.nn as nn
-from huggingface_hub import hf_hub_download

-sys.path.insert(0, 'projected_gan')
+from model import Model

-TITLE = 'autonomousvision/projected_gan'
-DESCRIPTION = '''This is an unofficial demo for https://github.com/autonomousvision/projected_gan.
+TITLE = '# autonomousvision/projected_gan'
+DESCRIPTION = '''This is an unofficial demo for [https://github.com/autonomousvision/projected_gan](https://github.com/autonomousvision/projected_gan).

 Expected execution time on Hugging Face Spaces: 1s
 '''
-SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/projected_gan/resolve/main/samples'
-ARTICLE = f'''## Generated images
-- truncation: 0.7
-- size: 256x256
-- seed: 0-99
-### Art painting
-![Art painting samples]({SAMPLE_IMAGE_DIR}/art_painting.jpg)
-### Bedroom
-![Bedroom samples]({SAMPLE_IMAGE_DIR}/bedroom.jpg)
-### Church
-![Church samples]({SAMPLE_IMAGE_DIR}/church.jpg)
-### Cityscapes
-![Cityscapes samples]({SAMPLE_IMAGE_DIR}/cityscapes.jpg)
-### CLEVR
-![CLEVR samples]({SAMPLE_IMAGE_DIR}/clevr.jpg)
-### FFHQ
-![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
-### Flowers
-![Flowers samples]({SAMPLE_IMAGE_DIR}/flowers.jpg)
-### Landscape
-![Landscape samples]({SAMPLE_IMAGE_DIR}/landscape.jpg)
-### Pokemon
-![Pokemon samples]({SAMPLE_IMAGE_DIR}/pokemon.jpg)
-
-<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.projected_gan" alt="visitor badge"/></center>
-'''
-
-TOKEN = os.environ['TOKEN']
+FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.projected_gan" />'


 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument('--device', type=str, default='cpu')
     parser.add_argument('--theme', type=str)
-    parser.add_argument('--live', action='store_true')
     parser.add_argument('--share', action='store_true')
     parser.add_argument('--port', type=int)
     parser.add_argument('--disable-queue',
                         dest='enable_queue',
                         action='store_false')
-    parser.add_argument('--allow-flagging', type=str, default='never')
     return parser.parse_args()


-def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
-    return torch.from_numpy(
-        np.random.RandomState(seed).randn(1,
-                                          z_dim).astype(np.float32)).to(device)
-
-
-@torch.inference_mode()
-def generate_image(model_name: str, seed: int, truncation_psi: float,
-                   model_dict: dict[str, nn.Module],
-                   device: torch.device) -> np.ndarray:
-    model = model_dict[model_name]
-    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
-
-    z = generate_z(model.z_dim, seed, device)
-    label = torch.zeros([1, model.c_dim], device=device)
-
-    out = model(z, label, truncation_psi=truncation_psi)
-    out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-    return out[0].cpu().numpy()
+def get_sample_image_url(name: str) -> str:
+    sample_image_dir = 'https://huggingface.co/spaces/hysts/projected_gan/resolve/main/samples'
+    return f'{sample_image_dir}/{name}.jpg'


-def load_model(model_name: str, device: torch.device) -> nn.Module:
-    path = hf_hub_download('hysts/projected_gan',
-                           f'models/{model_name}.pkl',
-                           use_auth_token=TOKEN)
-    with open(path, 'rb') as f:
-        model = pickle.load(f)['G_ema']
-    model.eval()
-    model.to(device)
-    with torch.inference_mode():
-        z = torch.zeros((1, model.z_dim)).to(device)
-        label = torch.zeros([1, model.c_dim], device=device)
-        model(z, label)
-    return model
+def get_sample_image_markdown(name: str) -> str:
+    url = get_sample_image_url(name)
+    return f'''
+    - size: 256x256
+    - seed: 0-99
+    - truncation: 0.7
+    ![sample images]({url})'''


 def main():
     args = parse_args()
-    device = torch.device(args.device)
-
-    model_names = [
-        'art_painting',
-        'church',
-        'bedroom',
-        'cityscapes',
-        'clevr',
-        'ffhq',
-        'flowers',
-        'landscape',
-        'pokemon',
-    ]
-
-    model_dict = {name: load_model(name, device) for name in model_names}
-
-    func = functools.partial(generate_image,
-                             model_dict=model_dict,
-                             device=device)
-    func = functools.update_wrapper(func, generate_image)
-
-    gr.Interface(
-        func,
-        [
-            gr.inputs.Radio(
-                model_names, type='value', default='pokemon', label='Model'),
-            gr.inputs.Number(default=0, label='Seed'),
-            gr.inputs.Slider(
-                0, 2, step=0.05, default=0.7, label='Truncation psi'),
-        ],
-        gr.outputs.Image(type='numpy', label='Output'),
-        title=TITLE,
-        description=DESCRIPTION,
-        article=ARTICLE,
-        theme=args.theme,
-        allow_flagging=args.allow_flagging,
-        live=args.live,
-    ).launch(
+    model = Model(args.device)
+
+    with gr.Blocks(theme=args.theme, css='style.css') as demo:
+        gr.Markdown(TITLE)
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Tabs():
+            with gr.TabItem('App'):
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Group():
+                            model_name = gr.Dropdown(
+                                model.MODEL_NAMES,
+                                value=model.MODEL_NAMES[8],
+                                label='Model')
+                            seed = gr.Slider(0,
+                                             np.iinfo(np.uint32).max,
+                                             step=1,
+                                             value=0,
+                                             label='Seed')
+                            psi = gr.Slider(0,
+                                            2,
+                                            step=0.05,
+                                            value=0.7,
+                                            label='Truncation psi')
+                            run_button = gr.Button('Run')
+                    with gr.Column():
+                        result = gr.Image(label='Result', elem_id='result')
+
+            with gr.TabItem('Sample Images'):
+                with gr.Row():
+                    model_name2 = gr.Dropdown(model.MODEL_NAMES,
+                                              value=model.MODEL_NAMES[0],
+                                              label='Model')
+                with gr.Row():
+                    text = get_sample_image_markdown(model_name2.value)
+                    sample_images = gr.Markdown(text)
+
+        gr.Markdown(FOOTER)
+
+        model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
+        run_button.click(fn=model.set_model_and_generate_image,
+                         inputs=[
+                             model_name,
+                             seed,
+                             psi,
+                         ],
+                         outputs=result)
+        model_name2.change(fn=get_sample_image_markdown,
+                           inputs=model_name2,
+                           outputs=sample_images)
+
+    demo.launch(
         enable_queue=args.enable_queue,
         server_port=args.port,
        share=args.share,
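The rewritten app.py replaces the old `gr.Interface` setup with a `gr.Blocks` layout and delegates all model handling to the new `Model` class in model.py. A stripped-down sketch of the same wiring pattern, assuming a Gradio 3.x Blocks API and using a placeholder function instead of the real generator:

```python
import gradio as gr
import numpy as np


def fake_generate(model_name: str, seed: int, psi: float) -> np.ndarray:
    # Placeholder for Model.set_model_and_generate_image: returns random pixels.
    rng = np.random.RandomState(int(seed))
    return rng.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)


with gr.Blocks() as demo:
    model_name = gr.Dropdown(['pokemon', 'ffhq'], value='pokemon', label='Model')
    seed = gr.Slider(0, np.iinfo(np.uint32).max, step=1, value=0, label='Seed')
    psi = gr.Slider(0, 2, step=0.05, value=0.7, label='Truncation psi')
    run_button = gr.Button('Run')
    result = gr.Image(label='Result')
    # The click event drives generation, mirroring run_button.click in the diff.
    run_button.click(fn=fake_generate, inputs=[model_name, seed, psi], outputs=result)

demo.launch()
```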
model.py ADDED
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import os
+import pathlib
+import pickle
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+
+current_dir = pathlib.Path(__file__).parent
+submodule_dir = current_dir / 'projected_gan'
+sys.path.insert(0, submodule_dir.as_posix())
+
+HF_TOKEN = os.environ['HF_TOKEN']
+
+
+class Model:
+
+    MODEL_NAMES = [
+        'art_painting',
+        'church',
+        'bedroom',
+        'cityscapes',
+        'clevr',
+        'ffhq',
+        'flowers',
+        'landscape',
+        'pokemon',
+    ]
+
+    def __init__(self, device: str | torch.device):
+        self.device = torch.device(device)
+        self._download_all_models()
+        self.model_name = self.MODEL_NAMES[3]
+        self.model = self._load_model(self.model_name)
+
+    def _load_model(self, model_name: str) -> nn.Module:
+        path = hf_hub_download('hysts/projected_gan',
+                               f'models/{model_name}.pkl',
+                               use_auth_token=HF_TOKEN)
+        with open(path, 'rb') as f:
+            model = pickle.load(f)['G_ema']
+        model.eval()
+        model.to(self.device)
+        return model
+
+    def set_model(self, model_name: str) -> None:
+        if model_name == self.model_name:
+            return
+        self.model_name = model_name
+        self.model = self._load_model(model_name)
+
+    def _download_all_models(self):
+        for name in self.MODEL_NAMES:
+            self._load_model(name)
+
+    def generate_z(self, seed: int) -> torch.Tensor:
+        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+        z = np.random.RandomState(seed).randn(1, self.model.z_dim)
+        return torch.from_numpy(z).float().to(self.device)
+
+    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+        tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
+            torch.uint8)
+        return tensor.cpu().numpy()
+
+    @torch.inference_mode()
+    def generate(self, z: torch.Tensor, label: torch.Tensor,
+                 truncation_psi: float) -> torch.Tensor:
+        return self.model(z, label, truncation_psi=truncation_psi)
+
+    def generate_image(self, seed: int, truncation_psi: float) -> np.ndarray:
+        z = self.generate_z(seed)
+        label = torch.zeros([1, self.model.c_dim], device=self.device)
+
+        out = self.generate(z, label, truncation_psi)
+        out = self.postprocess(out)
+        return out[0]
+
+    def set_model_and_generate_image(self, model_name: str, seed: int,
+                                     truncation_psi: float) -> np.ndarray:
+        self.set_model(model_name)
+        return self.generate_image(seed, truncation_psi)
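model.py bundles checkpoint download (via `hf_hub_download`), caching of the current generator, latent sampling, and post-processing into one class. A brief usage sketch, assuming `HF_TOKEN` is set in the environment and the `projected_gan` submodule is checked out as the module expects:

```python
from model import Model

# Constructing Model downloads every checkpoint in MODEL_NAMES up front
# (_download_all_models) and loads 'cityscapes' (MODEL_NAMES[3]) by default.
model = Model('cpu')

# Switch generators and render a single image; the result is a uint8 HxWxC array.
image = model.set_model_and_generate_image('pokemon', 0, 0.7)
print(image.shape, image.dtype)
```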
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+}
+div#result {
+  max-width: 600px;
+  max-height: 600px;
+}
+img#visitor-badge {
+  display: block;
+  margin: auto;
+}