endo-yuki-t commited on
Commit
12979fc
1 Parent(s): 715c849
Files changed (2) hide show
  1. README.md +1 -0
  2. models/StyleGANControler.py +4 -1
README.md CHANGED
@@ -31,6 +31,7 @@ The latent transformer can be trained with
31
  ```
32
  python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
33
  ```
 
34
 
35
  ## Citation
36
  Please cite our paper if you find the code useful:
 
31
  ```
32
  python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
33
  ```
34
+ To perform training with your dataset, you need first to train StyleGAN2 on your dataset using [rosinality's code](https://github.com/rosinality/stylegan2-pytorch) and then run the above script with specifying the trained weights.
35
 
36
  ## Citation
37
  Please cite our paper if you find the code useful:
models/StyleGANControler.py CHANGED
@@ -28,14 +28,17 @@ class StyleGANControler(nn.Module):
28
  self.style_num = 14
29
  elif 'anime' in self.opts.stylegan_weights:
30
  self.style_num = 16
 
 
31
 
32
  self.encoder = self.set_encoder()
33
  if self.style_num==18:
34
- self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
35
  elif self.style_num==16:
36
  self.decoder = Generator(512, 512, 8, channel_multiplier=2)
37
  elif self.style_num==14:
38
  self.decoder = Generator(256, 512, 8, channel_multiplier=2)
 
39
  self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
40
 
41
  # Load weights if needed
 
28
  self.style_num = 14
29
  elif 'anime' in self.opts.stylegan_weights:
30
  self.style_num = 16
31
+ else:
32
+ self.style_num = 18 #Please modify to adjust network architecture to your pre-trained StyleGAN2
33
 
34
  self.encoder = self.set_encoder()
35
  if self.style_num==18:
36
+ self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
37
  elif self.style_num==16:
38
  self.decoder = Generator(512, 512, 8, channel_multiplier=2)
39
  elif self.style_num==14:
40
  self.decoder = Generator(256, 512, 8, channel_multiplier=2)
41
+
42
  self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
43
 
44
  # Load weights if needed