hysts HF staff commited on
Commit
7db4b63
1 Parent(s): 8a48850
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. StyleGAN-Human +1 -0
  3. app.py +175 -0
  4. requirements.txt +5 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "StyleGAN-Human"]
2
+ path = StyleGAN-Human
3
+ url = https://github.com/stylegan-human/StyleGAN-Human
StyleGAN-Human ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ccab82ad02088debe106872d8d73f9fc4b250ad1
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import os
7
+ import pickle
8
+ import sys
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ sys.path.insert(0, 'StyleGAN-Human')
17
+
18
+ TOKEN = os.environ['TOKEN']
19
+
20
+
21
+ def parse_args() -> argparse.Namespace:
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--device', type=str, default='cpu')
24
+ parser.add_argument('--theme', type=str)
25
+ parser.add_argument('--share', action='store_true')
26
+ parser.add_argument('--port', type=int)
27
+ parser.add_argument('--disable-queue',
28
+ dest='enable_queue',
29
+ action='store_false')
30
+ return parser.parse_args()
31
+
32
+
33
+ class App:
34
+
35
+ def __init__(self, device: torch.device):
36
+ self.device = device
37
+ self.model = self.load_model('stylegan_human_v2_1024.pkl')
38
+
39
+ def load_model(self, file_name: str) -> nn.Module:
40
+ path = hf_hub_download('hysts/StyleGAN-Human',
41
+ f'models/{file_name}',
42
+ use_auth_token=TOKEN)
43
+ with open(path, 'rb') as f:
44
+ model = pickle.load(f)['G_ema']
45
+ model.eval()
46
+ model.to(self.device)
47
+ with torch.inference_mode():
48
+ z = torch.zeros((1, model.z_dim)).to(self.device)
49
+ label = torch.zeros([1, model.c_dim], device=self.device)
50
+ model(z, label, force_fp32=True)
51
+ return model
52
+
53
+ def generate_z(self, z_dim: int, seed: int) -> torch.Tensor:
54
+ return torch.from_numpy(np.random.RandomState(seed).randn(
55
+ 1, z_dim)).to(self.device).float()
56
+
57
+ @torch.inference_mode()
58
+ def generate_single_image(self, seed: int,
59
+ truncation_psi: float) -> np.ndarray:
60
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
61
+
62
+ z = self.generate_z(self.model.z_dim, seed)
63
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
64
+
65
+ out = self.model(z,
66
+ label,
67
+ truncation_psi=truncation_psi,
68
+ force_fp32=True)
69
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
70
+ torch.uint8)
71
+ return out[0].cpu().numpy()
72
+
73
+ @torch.inference_mode()
74
+ def generate_interpolated_images(
75
+ self, seed0: int, psi0: float, seed1: int, psi1: float,
76
+ num_intermediate: int) -> tuple[list[np.ndarray], np.ndarray]:
77
+ seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
78
+ seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
79
+
80
+ z0 = self.generate_z(self.model.z_dim, seed0)
81
+ z1 = self.generate_z(self.model.z_dim, seed1)
82
+ vec = z1 - z0
83
+ dvec = vec / (num_intermediate + 1)
84
+ zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
85
+ dpsi = (psi1 - psi0) / (num_intermediate + 1)
86
+ psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
87
+
88
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
89
+
90
+ res = []
91
+ for z, psi in zip(zs, psis):
92
+ out = self.model(z, label, truncation_psi=psi, force_fp32=True)
93
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
94
+ torch.uint8)
95
+ out = out[0].cpu().numpy()
96
+ res.append(out)
97
+ return res
98
+
99
+
100
+ def main():
101
+ args = parse_args()
102
+ app = App(device=torch.device(args.device))
103
+
104
+ with gr.Blocks(theme=args.theme) as demo:
105
+ gr.Markdown('''<center><h1>StyleGAN-Human</h1></center>
106
+
107
+ This is a Blocks version of [this app](https://huggingface.co/spaces/hysts/StyleGAN-Human) and [this app](https://huggingface.co/spaces/hysts/StyleGAN-Human-Interpolation).
108
+ ''')
109
+
110
+ with gr.Row():
111
+ with gr.Column():
112
+ with gr.Row():
113
+ seed1 = gr.Number(value=6876, label='Seed 1')
114
+ psi1 = gr.Slider(0,
115
+ 2,
116
+ value=0.7,
117
+ step=0.05,
118
+ label='Truncation psi 1')
119
+ with gr.Row():
120
+ generate_button1 = gr.Button('Generate')
121
+ with gr.Row():
122
+ generated_image1 = gr.Image(type='numpy',
123
+ label='Generated Image 1')
124
+
125
+ with gr.Column():
126
+ with gr.Row():
127
+ seed2 = gr.Number(value=6886, label='Seed 2')
128
+ psi2 = gr.Slider(0,
129
+ 2,
130
+ value=0.7,
131
+ step=0.05,
132
+ label='Truncation psi 2')
133
+ with gr.Row():
134
+ generate_button2 = gr.Button('Generate')
135
+ with gr.Row():
136
+ generated_image2 = gr.Image(type='numpy',
137
+ label='Generated Image 2')
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ with gr.Row():
142
+ num_frames = gr.Slider(
143
+ 0,
144
+ 41,
145
+ value=7,
146
+ step=1,
147
+ label='Number of Intermediate Frames')
148
+ with gr.Row():
149
+ interpolate_button = gr.Button('Interpolate')
150
+ with gr.Row():
151
+ interpolated_images = gr.Gallery(label='Output Images')
152
+
153
+ gr.Markdown(
154
+ '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.stylegan-human" alt="visitor badge"/></center>'
155
+ )
156
+
157
+ generate_button1.click(app.generate_single_image,
158
+ inputs=[seed1, psi1],
159
+ outputs=generated_image1)
160
+ generate_button2.click(app.generate_single_image,
161
+ inputs=[seed2, psi2],
162
+ outputs=generated_image2)
163
+ interpolate_button.click(app.generate_interpolated_images,
164
+ inputs=[seed1, psi1, seed2, psi2, num_frames],
165
+ outputs=interpolated_images)
166
+
167
+ demo.launch(
168
+ enable_queue=args.enable_queue,
169
+ server_port=args.port,
170
+ share=args.share,
171
+ )
172
+
173
+
174
+ if __name__ == '__main__':
175
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.1.0
3
+ scipy==1.8.0
4
+ torch==1.11.0
5
+ torchvision==0.12.0