hysts HF staff committed on
Commit
fa25938
1 Parent(s): b5e3fd5
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
1
+ *.pkl filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TADNE models
2
+
3
+ - aydao-anime-danbooru2019s-512-5268480.pkl
4
+ - https://drive.google.com/file/d/1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H
5
+
6
+ ## Model Conversion
7
+
8
+ The model in the `models` directory was converted using the following repository:
9
+ https://github.com/rosinality/stylegan2-pytorch
10
+
11
+ ### Apply patches
12
+ ```diff
13
+ --- a/model.py
14
+ +++ b/model.py
15
+ @@ -395,6 +395,7 @@ class Generator(nn.Module):
16
+ style_dim,
17
+ n_mlp,
18
+ channel_multiplier=2,
19
+ + additional_multiplier=2,
20
+ blur_kernel=[1, 3, 3, 1],
21
+ lr_mlp=0.01,
22
+ ):
23
+ @@ -426,6 +427,9 @@ class Generator(nn.Module):
24
+ 512: 32 * channel_multiplier,
25
+ 1024: 16 * channel_multiplier,
26
+ }
27
+ + if additional_multiplier > 1:
28
+ + for k in list(self.channels.keys()):
29
+ + self.channels[k] *= additional_multiplier
30
+
31
+ self.input = ConstantInput(self.channels[4])
32
+ self.conv1 = StyledConv(
33
+ @@ -518,7 +522,7 @@ class Generator(nn.Module):
34
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
35
+ ]
36
+
37
+ - if truncation < 1:
38
+ + if truncation_latent is not None:
39
+ style_t = []
40
+
41
+ for style in styles:
42
+ ```
43
+
44
+ ```diff
45
+ --- a/convert_weight.py
46
+ +++ b/convert_weight.py
47
+ @@ -221,6 +221,7 @@ if __name__ == "__main__":
48
+ default=2,
49
+ help="channel multiplier factor. config-f = 2, else = 1",
50
+ )
51
+ + parser.add_argument("--additional_multiplier", type=int, default=2)
52
+ parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
53
+
54
+ args = parser.parse_args()
55
+ @@ -243,7 +244,8 @@ if __name__ == "__main__":
56
+ if layer[0].startswith('Dense'):
57
+ n_mlp += 1
58
+
59
+ - g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
60
+ + style_dim = 512 * args.additional_multiplier
61
+ + g = Generator(size, style_dim, n_mlp, channel_multiplier=args.channel_multiplier, additional_multiplier=args.additional_multiplier)
62
+ state_dict = g.state_dict()
63
+ state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
64
+
65
+ @@ -254,7 +256,7 @@ if __name__ == "__main__":
66
+ ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
67
+
68
+ if args.gen:
69
+ - g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
70
+ + g_train = Generator(size, style_dim, n_mlp, channel_multiplier=args.channel_multiplier, additional_multiplier=args.additional_multiplier)
71
+ g_train_state = g_train.state_dict()
72
+ g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
73
+ ckpt["g"] = g_train_state
74
+ @@ -271,9 +273,12 @@ if __name__ == "__main__":
75
+ batch_size = {256: 16, 512: 9, 1024: 4}
76
+ n_sample = batch_size.get(size, 25)
77
+
78
+ + if args.additional_multiplier > 1:
79
+ + n_sample = 2
80
+ +
81
+ g = g.to(device)
82
+
83
+ - z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
84
+ + z = np.random.RandomState(0).randn(n_sample, style_dim).astype("float32")
85
+
86
+ with torch.no_grad():
87
+ img_pt, _ = g(
88
+ ```
89
+
90
+ ### Build Docker image
91
+
92
+ ```dockerfile
93
+ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
94
+
95
+ ENV DEBIAN_FRONTEND=noninteractive
96
+ RUN apt-get update -y && \
97
+ apt-get install -y --no-install-recommends \
98
+ git \
99
+ ninja-build \
100
+ # pyenv dependencies \
101
+ make \
102
+ build-essential \
103
+ libssl-dev \
104
+ zlib1g-dev \
105
+ libbz2-dev \
106
+ libreadline-dev \
107
+ libsqlite3-dev \
108
+ wget \
109
+ curl \
110
+ llvm \
111
+ libncursesw5-dev \
112
+ xz-utils \
113
+ tk-dev \
114
+ libxml2-dev \
115
+ libxmlsec1-dev \
116
+ libffi-dev \
117
+ liblzma-dev && \
118
+ apt-get clean && \
119
+ rm -rf /var/lib/apt/lists/*
120
+
121
+ ARG PYTHON_VERSION=3.7.12
122
+ ENV PYENV_ROOT /opt/pyenv
123
+ ENV PATH ${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH}
124
+ RUN curl https://pyenv.run | bash
125
+ RUN pyenv install ${PYTHON_VERSION} && \
126
+ pyenv global ${PYTHON_VERSION}
127
+ RUN pip install --no-cache-dir -U requests tqdm opencv-python-headless
128
+ RUN pip install --no-cache-dir -U tensorflow-gpu==1.15.4
129
+ RUN pip install --no-cache-dir -U torch==1.10.2+cu102 torchvision==0.11.3+cu102 -f https://download.pytorch.org/whl/torch/ -f https://download.pytorch.org/whl/torchvision/
130
+ RUN rm -rf ${HOME}/.cache/pip
131
+
132
+ WORKDIR /work
133
+ ENV PYTHONPATH /work/:${PYTHONPATH}
134
+ ```
135
+
136
+ ```bash
137
+ docker build . -t stylegan2_pytorch
138
+ ```
139
+
140
+ ### Convert
141
+ ```bash
142
+ git clone https://github.com/NVLabs/stylegan2
143
+ docker run --rm -it -u $(id -u):$(id -g) -e XDG_CACHE_HOME=/work --ipc host --gpus all -w /work -v `pwd`:/work stylegan2_pytorch python convert_weight.py --repo stylegan2 aydao-anime-danbooru2019s-512-5268480.pkl
144
+ ```
145
+
146
+ ## Usage
147
+ ### Apply patch
148
+ ```diff
149
+ --- a/generate.py
150
+ +++ b/generate.py
151
+ @@ -6,21 +6,25 @@ from model import Generator
152
+ from tqdm import tqdm
153
+
154
+
155
+ -def generate(args, g_ema, device, mean_latent):
156
+ +def generate(args, g_ema, device, mean_latent, randomize_noise):
157
+
158
+ with torch.no_grad():
159
+ g_ema.eval()
160
+ for i in tqdm(range(args.pics)):
161
+ - sample_z = torch.randn(args.sample, args.latent, device=device)
162
+ + samples = []
163
+ + for _ in range(args.split):
164
+ + sample_z = torch.randn(args.sample // args.split, args.latent, device=device)
165
+
166
+ - sample, _ = g_ema(
167
+ - [sample_z], truncation=args.truncation, truncation_latent=mean_latent
168
+ - )
169
+ + sample, _ = g_ema(
170
+ + [sample_z], truncation=args.truncation, truncation_latent=mean_latent,
171
+ + randomize_noise=randomize_noise
172
+ + )
173
+ + samples.extend(sample)
174
+
175
+ utils.save_image(
176
+ - sample,
177
+ - f"sample/{str(i).zfill(6)}.png",
178
+ - nrow=1,
179
+ + samples,
180
+ + f"{args.output_dir}/{str(i).zfill(6)}.{args.ext}",
181
+ + nrow=args.ncol,
182
+ normalize=True,
183
+ range=(-1, 1),
184
+ )
185
+ @@ -30,6 +34,8 @@ if __name__ == "__main__":
186
+ device = "cuda"
187
+
188
+ parser = argparse.ArgumentParser(description="Generate samples from the generator")
189
+ + parser.add_argument("--seed", type=int, default=0)
190
+ + parser.add_argument("--output-dir", '-o', type=str, required=True)
191
+
192
+ parser.add_argument(
193
+ "--size", type=int, default=1024, help="output image size of the generator"
194
+ @@ -37,11 +43,14 @@ if __name__ == "__main__":
195
+ parser.add_argument(
196
+ "--sample",
197
+ type=int,
198
+ - default=1,
199
+ + default=100,
200
+ help="number of samples to be generated for each image",
201
+ )
202
+ + parser.add_argument("--ncol", type=int, default=10)
203
+ + parser.add_argument("--split", type=int, default=4)
204
+ + parser.add_argument("--ext", type=str, default='png')
205
+ parser.add_argument(
206
+ - "--pics", type=int, default=20, help="number of images to be generated"
207
+ + "--pics", type=int, default=1, help="number of images to be generated"
208
+ )
209
+ parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
210
+ parser.add_argument(
211
+ @@ -62,23 +71,31 @@ if __name__ == "__main__":
212
+ default=2,
213
+ help="channel multiplier of the generator. config-f = 2, else = 1",
214
+ )
215
+ + parser.add_argument("--additional_multiplier", type=int, default=1)
216
+ + parser.add_argument("--load_latent_vec", action='store_true')
217
+ + parser.add_argument("--no-randomize-noise", dest='randomize_noise', action='store_false')
218
+ + parser.add_argument("--n_mlp", type=int, default=8)
219
+
220
+ args = parser.parse_args()
221
+
222
+ - args.latent = 512
223
+ - args.n_mlp = 8
224
+ + seed = args.seed
225
+ + torch.manual_seed(seed)
226
+ + torch.cuda.manual_seed_all(seed)
227
+ +
228
+ + args.latent = 512 * args.additional_multiplier
229
+
230
+ g_ema = Generator(
231
+ - args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
232
+ + args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier,
233
+ + additional_multiplier=args.additional_multiplier
234
+ ).to(device)
235
+ checkpoint = torch.load(args.ckpt)
236
+
237
+ - g_ema.load_state_dict(checkpoint["g_ema"])
238
+ + g_ema.load_state_dict(checkpoint["g_ema"], strict=True)
239
+
240
+ - if args.truncation < 1:
241
+ + if not args.load_latent_vec:
242
+ with torch.no_grad():
243
+ mean_latent = g_ema.mean_latent(args.truncation_mean)
244
+ else:
245
+ - mean_latent = None
246
+ + mean_latent = checkpoint['latent_avg'].to(device)
247
+
248
+ - generate(args, g_ema, device, mean_latent)
249
+ + generate(args, g_ema, device, mean_latent, randomize_noise=args.randomize_noise)
250
+ ```
251
+
252
+ ### Run
253
+ ```bash
254
+ python generate.py --ckpt aydao-anime-danbooru2019s-512-5268480.pt --size 512 --n_mlp 4 --additional_multiplier 2 --load_latent_vec --no-randomize-noise -o out_images --truncation 0.6 --seed 333 --pics 1 --sample 48 --ncol 8 --ext jpg
255
+ ```
256
+
models/aydao-anime-danbooru2019s-512-5268480.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecfdff89ad1ff94165b982e1906b2b8cf5fbedab47ac7ba43c91d3513b6b50d5
3
+ size 470194205
orig/aydao-anime-danbooru2019s-512-5268480.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e15a64f88f93c057da91311d6cce74db540f651f5f69e9bb66ed865321f354c
3
+ size 1056544230