hysts HF staff commited on
Commit
6815db7
1 Parent(s): 6b930f7
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. app.py +161 -0
  3. requirements.txt +6 -0
  4. stylegan_xl +1 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "stylegan_xl"]
2
+ path = stylegan_xl
3
+ url = https://github.com/autonomousvision/stylegan_xl
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ sys.path.insert(0, 'stylegan_xl')
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ ORIGINAL_REPO_URL = 'https://github.com/autonomousvision/stylegan_xl'
20
+ TITLE = 'autonomousvision/stylegan_xl'
21
+ DESCRIPTION = f'''This is a demo for {ORIGINAL_REPO_URL}.
22
+
23
+ For class-conditional models, you can specify the class index.
24
+ Index-to-label dictionaries for ImageNet and CIFAR-10 can be found [here](https://raw.githubusercontent.com/autonomousvision/stylegan_xl/main/misc/imagenet_idx2labels.txt) and [here](https://www.cs.toronto.edu/~kriz/cifar.html), respectively.
25
+ '''
26
+ ARTICLE = None
27
+
28
+ TOKEN = os.environ['TOKEN']
29
+
30
+
31
+ def parse_args() -> argparse.Namespace:
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--device', type=str, default='cpu')
34
+ parser.add_argument('--theme', type=str)
35
+ parser.add_argument('--live', action='store_true')
36
+ parser.add_argument('--share', action='store_true')
37
+ parser.add_argument('--port', type=int)
38
+ parser.add_argument('--disable-queue',
39
+ dest='enable_queue',
40
+ action='store_false')
41
+ parser.add_argument('--allow-flagging', type=str, default='never')
42
+ parser.add_argument('--allow-screenshot', action='store_true')
43
+ return parser.parse_args()
44
+
45
+
46
+ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
47
+ mat = np.eye(3)
48
+ sin = np.sin(angle / 360 * np.pi * 2)
49
+ cos = np.cos(angle / 360 * np.pi * 2)
50
+ mat[0][0] = cos
51
+ mat[0][1] = sin
52
+ mat[0][2] = translate[0]
53
+ mat[1][0] = -sin
54
+ mat[1][1] = cos
55
+ mat[1][2] = translate[1]
56
+ return mat
57
+
58
+
59
+ def generate_z(seed: int, device: torch.device) -> torch.Tensor:
60
+ return torch.from_numpy(np.random.RandomState(seed).randn(1,
61
+ 64)).to(device)
62
+
63
+
64
+ @torch.inference_mode()
65
+ def generate_image(model_name: str, class_index: int, seed: int,
66
+ truncation_psi: float, tx: float, ty: float, angle: float,
67
+ model_dict: dict[str, nn.Module],
68
+ device: torch.device) -> np.ndarray:
69
+ model = model_dict[model_name]
70
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
71
+
72
+ z = generate_z(seed, device)
73
+
74
+ label = torch.zeros([1, model.c_dim], device=device)
75
+ class_index = round(class_index)
76
+ class_index = min(max(0, class_index), model.c_dim - 1)
77
+ class_index = torch.tensor(class_index, dtype=torch.long)
78
+ if class_index >= 0:
79
+ label[:, class_index] = 1
80
+
81
+ mat = make_transform((tx, ty), angle)
82
+ mat = np.linalg.inv(mat)
83
+ model.synthesis.input.transform.copy_(torch.from_numpy(mat))
84
+
85
+ out = model(z, label, truncation_psi=truncation_psi)
86
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
87
+ return out[0].cpu().numpy()
88
+
89
+
90
+ def load_model(model_name: str, device: torch.device) -> nn.Module:
91
+ path = hf_hub_download('hysts/StyleGAN-XL',
92
+ f'models/{model_name}.pkl',
93
+ use_auth_token=TOKEN)
94
+ with open(path, 'rb') as f:
95
+ model = pickle.load(f)['G_ema']
96
+ model.eval()
97
+ model.to(device)
98
+ with torch.inference_mode():
99
+ z = torch.zeros((1, 64)).to(device)
100
+ label = torch.zeros([1, model.c_dim], device=device)
101
+ model(z, label)
102
+ return model
103
+
104
+
105
+ def main():
106
+ gr.close_all()
107
+
108
+ args = parse_args()
109
+ device = torch.device(args.device)
110
+
111
+ model_names = [
112
+ 'imagenet16',
113
+ 'imagenet32',
114
+ 'imagenet64',
115
+ 'imagenet128',
116
+ 'cifar10',
117
+ 'ffhq256',
118
+ 'pokemon256',
119
+ ]
120
+
121
+ model_dict = {name: load_model(name, device) for name in model_names}
122
+
123
+ func = functools.partial(generate_image,
124
+ model_dict=model_dict,
125
+ device=device)
126
+ func = functools.update_wrapper(func, generate_image)
127
+
128
+ gr.Interface(
129
+ func,
130
+ [
131
+ gr.inputs.Radio(
132
+ model_names,
133
+ type='value',
134
+ default='imagenet128',
135
+ label='Model',
136
+ ),
137
+ gr.inputs.Number(default=284, label='Class index'),
138
+ gr.inputs.Number(default=0, label='Seed'),
139
+ gr.inputs.Slider(
140
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
141
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
142
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
143
+ gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
144
+ ],
145
+ gr.outputs.Image(type='numpy', label='Output'),
146
+ theme=args.theme,
147
+ title=TITLE,
148
+ description=DESCRIPTION,
149
+ article=ARTICLE,
150
+ allow_screenshot=args.allow_screenshot,
151
+ allow_flagging=args.allow_flagging,
152
+ live=args.live,
153
+ ).launch(
154
+ enable_queue=args.enable_queue,
155
+ server_port=args.port,
156
+ share=args.share,
157
+ )
158
+
159
+
160
+ if __name__ == '__main__':
161
+ main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ ftfy==6.1.1
2
+ numpy==1.22.3
3
+ Pillow==9.0.1
4
+ scipy==1.8.0
5
+ torch==1.11.0
6
+ torchvision==0.12.0
stylegan_xl ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit 754f491583c96ff1c8ee9c05762aef1835b6d0b9