hysts HF staff commited on
Commit
1dd3178
1 Parent(s): 57ff62a
Files changed (2) hide show
  1. app.py +99 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import os
7
+ import pickle
8
+ import sys
9
+
10
+ sys.path.insert(0, 'stylegan3')
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import PIL.Image
15
+ import torch
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ MODEL_REPO = 'hysts/stylegan3-anime-face-exp001-model'
19
+ MODEL_FILE_NAME = '006600.pkl'
20
+ TOKEN = os.environ['TOKEN']
21
+
22
+ DEFAULT_SEED = 3407851645
23
+
24
+ TITLE = 'StyleGAN3 Anime Face Generation'
25
+
26
+
27
+ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
28
+ mat = np.eye(3)
29
+ sin = np.sin(angle / 360 * np.pi * 2)
30
+ cos = np.cos(angle / 360 * np.pi * 2)
31
+ mat[0][0] = cos
32
+ mat[0][1] = sin
33
+ mat[0][2] = translate[0]
34
+ mat[1][0] = -sin
35
+ mat[1][1] = cos
36
+ mat[1][2] = translate[1]
37
+ return mat
38
+
39
+
40
+ def generate_z(seed, device):
41
+ return torch.from_numpy(np.random.RandomState(seed).randn(1,
42
+ 512)).to(device)
43
+
44
+
45
+ @torch.inference_mode()
46
+ def generate_image(seed, truncation_psi, tx, ty, angle, model, device):
47
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
48
+ z = generate_z(seed, device)
49
+ c = torch.zeros(0).to(device)
50
+
51
+ mat = make_transform((tx, ty), angle)
52
+ mat = np.linalg.inv(mat)
53
+ model.synthesis.input.transform.copy_(torch.from_numpy(mat))
54
+
55
+ out = model(z, c, truncation_psi=truncation_psi)
56
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
57
+ return PIL.Image.fromarray(out[0].cpu().numpy(), 'RGB')
58
+
59
+
60
+ def load_model(device):
61
+ path = hf_hub_download(MODEL_REPO, MODEL_FILE_NAME, use_auth_token=TOKEN)
62
+ with open(path, 'rb') as f:
63
+ model = pickle.load(f)
64
+ model.eval()
65
+ model.to(device)
66
+ with torch.inference_mode():
67
+ z = torch.zeros((1, 512)).to(device)
68
+ c = torch.zeros(0).to(device)
69
+ model(z, c)
70
+ return model
71
+
72
+
73
+ def main():
74
+ device = torch.device('cpu')
75
+
76
+ model = load_model(device)
77
+ func = functools.partial(generate_image, model=model, device=device)
78
+ func = functools.update_wrapper(func, generate_image)
79
+
80
+ gr.Interface(
81
+ func,
82
+ [
83
+ gr.inputs.Number(default=DEFAULT_SEED, label='Seed'),
84
+ gr.inputs.Slider(
85
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
86
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
87
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
88
+ gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
89
+ ],
90
+ gr.outputs.Image(type='pil', label='Output'),
91
+ title=TITLE,
92
+ enable_queue=True,
93
+ allow_screenshot=False,
94
+ allow_flagging=False,
95
+ ).launch()
96
+
97
+
98
+ if __name__ == '__main__':
99
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy>=1.21.4
2
+ Pillow>=8.3.1
3
+ scipy>=1.7.2
4
+ torch>=1.10.0