Spaces:
Runtime error
Runtime error
endo-yuki-t
commited on
Commit
•
12979fc
1
Parent(s):
715c849
minor fix
Browse files- README.md +1 -0
- models/StyleGANControler.py +4 -1
README.md
CHANGED
@@ -31,6 +31,7 @@ The latent transformer can be trained with
|
|
31 |
```
|
32 |
python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
|
33 |
```
|
|
|
34 |
|
35 |
## Citation
|
36 |
Please cite our paper if you find the code useful:
|
|
|
31 |
```
|
32 |
python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
|
33 |
```
|
34 |
+
To perform training with your dataset, you need first to train StyleGAN2 on your dataset using [rosinality's code](https://github.com/rosinality/stylegan2-pytorch) and then run the above script with specifying the trained weights.
|
35 |
|
36 |
## Citation
|
37 |
Please cite our paper if you find the code useful:
|
models/StyleGANControler.py
CHANGED
@@ -28,14 +28,17 @@ class StyleGANControler(nn.Module):
|
|
28 |
self.style_num = 14
|
29 |
elif 'anime' in self.opts.stylegan_weights:
|
30 |
self.style_num = 16
|
|
|
|
|
31 |
|
32 |
self.encoder = self.set_encoder()
|
33 |
if self.style_num==18:
|
34 |
-
self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
|
35 |
elif self.style_num==16:
|
36 |
self.decoder = Generator(512, 512, 8, channel_multiplier=2)
|
37 |
elif self.style_num==14:
|
38 |
self.decoder = Generator(256, 512, 8, channel_multiplier=2)
|
|
|
39 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
40 |
|
41 |
# Load weights if needed
|
|
|
28 |
self.style_num = 14
|
29 |
elif 'anime' in self.opts.stylegan_weights:
|
30 |
self.style_num = 16
|
31 |
+
else:
|
32 |
+
self.style_num = 18 #Please modify to adjust network architecture to your pre-trained StyleGAN2
|
33 |
|
34 |
self.encoder = self.set_encoder()
|
35 |
if self.style_num==18:
|
36 |
+
self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
|
37 |
elif self.style_num==16:
|
38 |
self.decoder = Generator(512, 512, 8, channel_multiplier=2)
|
39 |
elif self.style_num==14:
|
40 |
self.decoder = Generator(256, 512, 8, channel_multiplier=2)
|
41 |
+
|
42 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
43 |
|
44 |
# Load weights if needed
|