Committed by hysts (HF staff)
Commit 518857d
Parent: 3c59a24
Files changed (7):
  1. .gitattributes +1 -0
  2. .gitmodules +3 -0
  3. MobileStyleGAN.pytorch +1 -0
  4. app.py +117 -0
  5. model.py +47 -0
  6. requirements.txt +8 -0
  7. samples/ffhq.jpg +3 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "MobileStyleGAN.pytorch"]
+	path = MobileStyleGAN.pytorch
+	url = https://github.com/bes-dev/MobileStyleGAN.pytorch
MobileStyleGAN.pytorch ADDED
@@ -0,0 +1 @@
+Subproject commit a9776ff8f05a868b2d3b637bda14eca4c074d2a3
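Note: model.py below imports `core.models.*` from this pinned checkout, so the submodule must be materialized before the app starts. The usual route is simply running `git submodule update --init` in the repository root; the Python wrapper below is only an illustrative sketch, not part of the commit.

import subprocess

# Fetch the pinned MobileStyleGAN.pytorch checkout (a9776ff...) so that
# the `core.models.*` imports in model.py resolve. Equivalent to running
# `git submodule update --init --recursive` in the repository root.
subprocess.run(['git', 'submodule', 'update', '--init', '--recursive'],
               check=True)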
app.py ADDED
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import functools
+import os
+
+import gradio as gr
+import numpy as np
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+
+from model import Model
+
+ORIGINAL_REPO_URL = 'https://github.com/bes-dev/MobileStyleGAN.pytorch'
+TITLE = 'bes-dev/MobileStyleGAN.pytorch'
+DESCRIPTION = f'This is a demo for {ORIGINAL_REPO_URL}.'
+SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/MobileStyleGAN/resolve/main/samples'
+ARTICLE = f'''## Generated images
+### FFHQ
+- size: 1024x1024
+- seed: 0-99
+- truncation: 1.0
+![FFHQ]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
+'''
+
+TOKEN = os.environ['TOKEN']
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument('--theme', type=str)
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    parser.add_argument('--allow-screenshot', action='store_true')
+    return parser.parse_args()
+
+
+def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
+    return torch.from_numpy(np.random.RandomState(seed).randn(
+        1, z_dim)).to(device).float()
+
+
+@torch.inference_mode()
+def generate_image(seed: int, truncation_psi: float, generator: str,
+                   model: nn.Module, device: torch.device) -> np.ndarray:
+    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+
+    z = generate_z(model.mapping_net.style_dim, seed, device)
+
+    out = model(z, truncation_psi=truncation_psi, generator=generator)
+    out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+    return out[0].cpu().numpy()
+
+
+def load_model(device: torch.device) -> nn.Module:
+    path = hf_hub_download('hysts/MobileStyleGAN',
+                           'models/mobilestylegan_ffhq_v2.pth',
+                           use_auth_token=TOKEN)
+    ckpt = torch.load(path)
+    model = Model()
+    model.load_state_dict(ckpt['state_dict'], strict=False)
+    model.eval()
+    model.to(device)
+    with torch.inference_mode():
+        z = torch.zeros((1, model.mapping_net.style_dim)).to(device)
+        model(z)
+    return model
+
+
+def main():
+    gr.close_all()
+
+    args = parse_args()
+    device = torch.device(args.device)
+
+    model = load_model(device)
+
+    func = functools.partial(generate_image, model=model, device=device)
+    func = functools.update_wrapper(func, generate_image)
+
+    gr.Interface(
+        func,
+        [
+            gr.inputs.Number(default=0, label='Seed'),
+            gr.inputs.Slider(
+                0, 2, step=0.05, default=1.0, label='Truncation psi'),
+            gr.inputs.Radio(['student', 'teacher'],
+                            type='value',
+                            default='student',
+                            label='Generator'),
+        ],
+        gr.outputs.Image(type='numpy', label='Output'),
+        title=TITLE,
+        description=DESCRIPTION,
+        article=ARTICLE,
+        theme=args.theme,
+        allow_screenshot=args.allow_screenshot,
+        allow_flagging=args.allow_flagging,
+        live=args.live,
+    ).launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
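Note: outside the Gradio UI, the functions above compose as in this hypothetical sketch (not part of the commit). It assumes the `TOKEN` environment variable is set and that the token grants access to the `hysts/MobileStyleGAN` model repo.

# Hypothetical usage sketch: render one image without launching the UI.
import os
assert 'TOKEN' in os.environ  # required by app.py at import time

import PIL.Image
import torch

from app import generate_image, load_model

device = torch.device('cpu')
model = load_model(device)  # downloads mobilestylegan_ffhq_v2.pth once

# seed=0, truncation_psi=1.0, distilled student generator
image = generate_image(0, 1.0, 'student', model, device)
PIL.Image.fromarray(image).save('sample.jpg')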
model.py ADDED
@@ -0,0 +1,47 @@
+import sys
+
+import torch
+import torch.nn as nn
+
+sys.path.insert(0, 'MobileStyleGAN.pytorch')
+
+from core.models.mapping_network import MappingNetwork
+from core.models.mobile_synthesis_network import MobileSynthesisNetwork
+from core.models.synthesis_network import SynthesisNetwork
+
+
+class Model(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        # teacher model
+        mapping_net_params = {'style_dim': 512, 'n_layers': 8, 'lr_mlp': 0.01}
+        synthesis_net_params = {
+            'size': 1024,
+            'style_dim': 512,
+            'blur_kernel': [1, 3, 3, 1],
+            'channels': [512, 512, 512, 512, 512, 256, 128, 64, 32]
+        }
+        self.mapping_net = MappingNetwork(**mapping_net_params).eval()
+        self.synthesis_net = SynthesisNetwork(**synthesis_net_params).eval()
+        # student network
+        self.student = MobileSynthesisNetwork(
+            style_dim=self.mapping_net.style_dim,
+            channels=synthesis_net_params['channels'][:-1])
+
+        self.style_mean = nn.Parameter(torch.zeros((1, 512)),
+                                       requires_grad=False)
+
+    def forward(self,
+                var: torch.Tensor,
+                truncation_psi: float = 0.5,
+                generator: str = 'student') -> torch.Tensor:
+        style = self.mapping_net(var)
+        style = self.style_mean + truncation_psi * (style - self.style_mean)
+        if generator == 'student':
+            img = self.student(style)['img']
+        elif generator == 'teacher':
+            img = self.synthesis_net(style)['img']
+        else:
+            raise ValueError
+        return img
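Note: the two `style = ...` lines in `forward` are the standard StyleGAN truncation trick: the mapped style is pulled toward the mean style (`style_mean`, restored from the checkpoint) by a factor `truncation_psi`, trading sample diversity for fidelity; `truncation_psi=1.0` disables truncation. A minimal sketch of driving the wrapper directly, assuming `model` was built and loaded as in app.py's `load_model`:

# Minimal sketch (assumes `model` is a loaded Model, as in app.py):
# run the same latent through the distilled student and the full teacher.
import torch

z = torch.randn(1, model.mapping_net.style_dim)  # style_dim == 512
with torch.inference_mode():
    img_student = model(z, truncation_psi=0.7, generator='student')
    img_teacher = model(z, truncation_psi=0.7, generator='teacher')
print(img_student.shape, img_teacher.shape)  # expected (1, 3, 1024, 1024) for FFHQ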
requirements.txt ADDED
@@ -0,0 +1,8 @@
+numpy==1.22.3
+Pillow==9.0.1
+PyWavelets==1.2.0
+piq==0.6.0
+scipy==1.8.0
+torch==1.11.0
+torchvision==0.12.0
+git+https://github.com/fbcotter/pytorch_wavelets.git
samples/ffhq.jpg ADDED

Git LFS Details

  • SHA256: a2e91194dfbc9948235876ca480ab85431a159117c771a65715082608624faf8
  • Pointer size: 133 Bytes
  • Size of remote file: 28.3 MB