hysts commited on
Commit
161d647
β€’
1 Parent(s): f5fa003
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. .gitmodules +3 -0
  3. app.py +124 -0
  4. patch +85 -0
  5. requirements.txt +4 -0
  6. samples/sample.jpg +3 -0
  7. stylegan2-pytorch +1 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "stylegan2-pytorch"]
2
+ path = stylegan2-pytorch
3
+ url = https://github.com/rosinality/stylegan2-pytorch
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ if os.environ.get('SYSTEM') == 'spaces':
18
+ subprocess.call('git apply ../patch'.split(), cwd='stylegan2-pytorch')
19
+
20
+ sys.path.insert(0, 'stylegan2-pytorch')
21
+
22
+ from model import Generator
23
+
24
+ TITLE = 'TADNE (This Anime Does Not Exist)'
25
+ DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.
26
+ The model used here is the one converted from the model provided in [this site](https://www.gwern.net/Faces) using [this repo](https://github.com/rosinality/stylegan2-pytorch).
27
+ '''
28
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/TADNE/resolve/main/samples'
29
+ ARTICLE = f'''## Generated images
30
+ - size: 512x512
31
+ - truncation: 0.7
32
+ - seed: 0-99
33
+ ![samples]({SAMPLE_IMAGE_DIR}/sample.jpg)
34
+ '''
35
+
36
+ TOKEN = os.environ['TOKEN']
37
+
38
+
39
+ def parse_args() -> argparse.Namespace:
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument('--device', type=str, default='cpu')
42
+ parser.add_argument('--theme', type=str)
43
+ parser.add_argument('--live', action='store_true')
44
+ parser.add_argument('--share', action='store_true')
45
+ parser.add_argument('--port', type=int)
46
+ parser.add_argument('--disable-queue',
47
+ dest='enable_queue',
48
+ action='store_false')
49
+ parser.add_argument('--allow-flagging', type=str, default='never')
50
+ parser.add_argument('--allow-screenshot', action='store_true')
51
+ return parser.parse_args()
52
+
53
+
54
+ def load_model(device: torch.device) -> nn.Module:
55
+ model = Generator(512, 1024, 4, channel_multiplier=2)
56
+ path = hf_hub_download('hysts/TADNE',
57
+ 'models/aydao-anime-danbooru2019s-512-5268480.pt',
58
+ use_auth_token=TOKEN)
59
+ checkpoint = torch.load(path)
60
+ model.load_state_dict(checkpoint['g_ema'])
61
+ model.eval()
62
+ model.to(device)
63
+ model.latent_avg = checkpoint['latent_avg'].to(device)
64
+ with torch.inference_mode():
65
+ z = torch.zeros((1, model.style_dim)).to(device)
66
+ model([z], truncation=0.7, truncation_latent=model.latent_avg)
67
+ return model
68
+
69
+
70
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
71
+ return torch.from_numpy(np.random.RandomState(seed).randn(
72
+ 1, z_dim)).to(device).float()
73
+
74
+
75
+ @torch.inference_mode()
76
+ def generate_image(seed: int, truncation_psi: float, randomize_noise: bool,
77
+ model: nn.Module, device: torch.device) -> np.ndarray:
78
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
79
+
80
+ z = generate_z(model.style_dim, seed, device)
81
+ out, _ = model([z],
82
+ truncation=truncation_psi,
83
+ truncation_latent=model.latent_avg,
84
+ randomize_noise=randomize_noise)
85
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
86
+ return out[0].cpu().numpy()
87
+
88
+
89
+ def main():
90
+ gr.close_all()
91
+
92
+ args = parse_args()
93
+ device = torch.device(args.device)
94
+
95
+ model = load_model(device)
96
+
97
+ func = functools.partial(generate_image, model=model, device=device)
98
+ func = functools.update_wrapper(func, generate_image)
99
+
100
+ gr.Interface(
101
+ func,
102
+ [
103
+ gr.inputs.Number(default=0, label='Seed'),
104
+ gr.inputs.Slider(
105
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
106
+ gr.inputs.Checkbox(default=False, label='Randomize Noise'),
107
+ ],
108
+ gr.outputs.Image(type='numpy', label='Output'),
109
+ title=TITLE,
110
+ description=DESCRIPTION,
111
+ article=ARTICLE,
112
+ theme=args.theme,
113
+ allow_screenshot=args.allow_screenshot,
114
+ allow_flagging=args.allow_flagging,
115
+ live=args.live,
116
+ ).launch(
117
+ enable_queue=args.enable_queue,
118
+ server_port=args.port,
119
+ share=args.share,
120
+ )
121
+
122
+
123
+ if __name__ == '__main__':
124
+ main()
patch ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/model.py b/model.py
2
+ index 0134c39..3a7826c 100755
3
+ --- a/model.py
4
+ +++ b/model.py
5
+ @@ -395,6 +395,7 @@ class Generator(nn.Module):
6
+ style_dim,
7
+ n_mlp,
8
+ channel_multiplier=2,
9
+ + additional_multiplier=2,
10
+ blur_kernel=[1, 3, 3, 1],
11
+ lr_mlp=0.01,
12
+ ):
13
+ @@ -426,6 +427,9 @@ class Generator(nn.Module):
14
+ 512: 32 * channel_multiplier,
15
+ 1024: 16 * channel_multiplier,
16
+ }
17
+ + if additional_multiplier > 1:
18
+ + for k in list(self.channels.keys()):
19
+ + self.channels[k] *= additional_multiplier
20
+
21
+ self.input = ConstantInput(self.channels[4])
22
+ self.conv1 = StyledConv(
23
+ @@ -518,7 +522,7 @@ class Generator(nn.Module):
24
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
25
+ ]
26
+
27
+ - if truncation < 1:
28
+ + if truncation_latent is not None:
29
+ style_t = []
30
+
31
+ for style in styles:
32
+ diff --git a/op/fused_act.py b/op/fused_act.py
33
+ index 5d46e10..bc522ed 100755
34
+ --- a/op/fused_act.py
35
+ +++ b/op/fused_act.py
36
+ @@ -1,5 +1,3 @@
37
+ -import os
38
+ -
39
+ import torch
40
+ from torch import nn
41
+ from torch.nn import functional as F
42
+ @@ -7,16 +5,6 @@ from torch.autograd import Function
43
+ from torch.utils.cpp_extension import load
44
+
45
+
46
+ -module_path = os.path.dirname(__file__)
47
+ -fused = load(
48
+ - "fused",
49
+ - sources=[
50
+ - os.path.join(module_path, "fused_bias_act.cpp"),
51
+ - os.path.join(module_path, "fused_bias_act_kernel.cu"),
52
+ - ],
53
+ -)
54
+ -
55
+ -
56
+ class FusedLeakyReLUFunctionBackward(Function):
57
+ @staticmethod
58
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
59
+ diff --git a/op/upfirdn2d.py b/op/upfirdn2d.py
60
+ index 67e0375..6c5840e 100755
61
+ --- a/op/upfirdn2d.py
62
+ +++ b/op/upfirdn2d.py
63
+ @@ -1,5 +1,4 @@
64
+ from collections import abc
65
+ -import os
66
+
67
+ import torch
68
+ from torch.nn import functional as F
69
+ @@ -7,16 +6,6 @@ from torch.autograd import Function
70
+ from torch.utils.cpp_extension import load
71
+
72
+
73
+ -module_path = os.path.dirname(__file__)
74
+ -upfirdn2d_op = load(
75
+ - "upfirdn2d",
76
+ - sources=[
77
+ - os.path.join(module_path, "upfirdn2d.cpp"),
78
+ - os.path.join(module_path, "upfirdn2d_kernel.cu"),
79
+ - ],
80
+ -)
81
+ -
82
+ -
83
+ class UpFirDn2dBackward(Function):
84
+ @staticmethod
85
+ def forward(
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ torch==1.11.0
4
+ torchvision==0.12.0
samples/sample.jpg ADDED

Git LFS Details

  • SHA256: 2973e030ad45dc61759f6dcb1e2ffdff7b2aefb16704ba0551e7b80f6d289a09
  • Pointer size: 132 Bytes
  • Size of remote file: 8.23 MB
stylegan2-pytorch ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit bef283a1c24087da704d16c30abc8e36e63efa0e