rinong commited on
Commit
f9cb70e
1 Parent(s): 9a29e97

StyleSpace list now uses correct order.

Browse files
Files changed (2) hide show
  1. model/sg2_model.py +28 -27
  2. styleclip/styleclip_global.py +1 -1
model/sg2_model.py CHANGED
@@ -527,38 +527,39 @@ class Generator(nn.Module):
527
  styles = [self.style(s) for s in styles]
528
 
529
  s_codes = [{# const block
530
- self.modulation_layers[0]: self.modulation_layers[0](style[:, 0]),
531
- self.modulation_layers[1]: self.modulation_layers[1](style[:, 1]),
532
  # conv layers
533
- self.modulation_layers[2]: self.modulation_layers[2](style[:, 2]),
534
- self.modulation_layers[3]: self.modulation_layers[3](style[:, 3]),
535
- self.modulation_layers[5]: self.modulation_layers[5](style[:, 4]),
536
- self.modulation_layers[6]: self.modulation_layers[6](style[:, 5]),
537
- self.modulation_layers[8]: self.modulation_layers[8](style[:, 6]),
538
- self.modulation_layers[9]: self.modulation_layers[9](style[:, 7]),
539
- self.modulation_layers[11]: self.modulation_layers[11](style[:, 8]),
540
- self.modulation_layers[12]: self.modulation_layers[12](style[:, 9]),
541
- self.modulation_layers[14]: self.modulation_layers[14](style[:, 10]),
542
- self.modulation_layers[15]: self.modulation_layers[15](style[:, 11]),
543
- self.modulation_layers[17]: self.modulation_layers[17](style[:, 12]),
544
- self.modulation_layers[18]: self.modulation_layers[18](style[:, 13]),
545
- self.modulation_layers[20]: self.modulation_layers[20](style[:, 14]),
546
- self.modulation_layers[21]: self.modulation_layers[21](style[:, 15]),
547
- self.modulation_layers[23]: self.modulation_layers[23](style[:, 16]),
548
- self.modulation_layers[24]: self.modulation_layers[24](style[:, 17]),
549
  # toRGB layers
550
- self.modulation_layers[4]: self.modulation_layers[4](style[:, 3]),
551
- self.modulation_layers[7]: self.modulation_layers[7](style[:, 5]),
552
- self.modulation_layers[10]: self.modulation_layers[10](style[:, 7]),
553
- self.modulation_layers[13]: self.modulation_layers[13](style[:, 9]),
554
- self.modulation_layers[16]: self.modulation_layers[16](style[:, 11]),
555
- self.modulation_layers[19]: self.modulation_layers[19](style[:, 13]),
556
- self.modulation_layers[22]: self.modulation_layers[22](style[:, 15]),
557
- self.modulation_layers[25]: self.modulation_layers[25](style[:, 17]),
558
  } for style in styles]
559
-
560
  return s_codes
561
 
 
562
  def forward(
563
  self,
564
  styles,
 
527
  styles = [self.style(s) for s in styles]
528
 
529
  s_codes = [{# const block
530
+ self.modulation_layers[0]: self.modulation_layers[0](style[:, 0]), #s0
531
+ self.modulation_layers[1]: self.modulation_layers[1](style[:, 1]), #s1
532
  # conv layers
533
+ self.modulation_layers[2]: self.modulation_layers[2](style[:, 1]), #s2
534
+ self.modulation_layers[3]: self.modulation_layers[3](style[:, 2]), #s3
535
+ self.modulation_layers[4]: self.modulation_layers[4](style[:, 3]), #s5
536
+ self.modulation_layers[5]: self.modulation_layers[5](style[:, 4]), #s6
537
+ self.modulation_layers[6]: self.modulation_layers[6](style[:, 5]), #s8
538
+ self.modulation_layers[7]: self.modulation_layers[7](style[:, 6]), #s9
539
+ self.modulation_layers[8]: self.modulation_layers[8](style[:, 7]), #s11
540
+ self.modulation_layers[9]: self.modulation_layers[9](style[:, 8]), #s12
541
+ self.modulation_layers[10]: self.modulation_layers[10](style[:, 9]), #s14
542
+ self.modulation_layers[11]: self.modulation_layers[11](style[:, 10]), #s15
543
+ self.modulation_layers[12]: self.modulation_layers[12](style[:, 11]), #s17
544
+ self.modulation_layers[13]: self.modulation_layers[13](style[:, 12]), #s18
545
+ self.modulation_layers[14]: self.modulation_layers[14](style[:, 13]), #s20
546
+ self.modulation_layers[15]: self.modulation_layers[15](style[:, 14]), #s21
547
+ self.modulation_layers[16]: self.modulation_layers[16](style[:, 15]), #s23
548
+ self.modulation_layers[17]: self.modulation_layers[17](style[:, 16]), #s24
549
  # toRGB layers
550
+ self.modulation_layers[18]: self.modulation_layers[18](style[:, 3]), #s4
551
+ self.modulation_layers[19]: self.modulation_layers[19](style[:, 5]), #s7
552
+ self.modulation_layers[20]: self.modulation_layers[20](style[:, 7]), #s10
553
+ self.modulation_layers[21]: self.modulation_layers[21](style[:, 9]), #s13
554
+ self.modulation_layers[22]: self.modulation_layers[22](style[:, 11]), #s16
555
+ self.modulation_layers[23]: self.modulation_layers[23](style[:, 13]), #s19
556
+ self.modulation_layers[24]: self.modulation_layers[24](style[:, 15]), #s22
557
+ self.modulation_layers[25]: self.modulation_layers[25](style[:, 17]), #s25
558
  } for style in styles]
559
+
560
  return s_codes
561
 
562
+
563
  def forward(
564
  self,
565
  styles,
styleclip/styleclip_global.py CHANGED
@@ -114,7 +114,7 @@ def expand_to_full_dim(partial_tensor):
114
  start_idx = 0
115
  for conv_start, conv_end in CONV_CODE_INDICES:
116
  length = conv_end - conv_start
117
- full_dim_tensor[:, conv_start:conv_end] = partial_tensor[:, start_idx:start_idx + length]
118
  start_idx += length
119
 
120
  return full_dim_tensor
 
114
  start_idx = 0
115
  for conv_start, conv_end in CONV_CODE_INDICES:
116
  length = conv_end - conv_start
117
+ full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
118
  start_idx += length
119
 
120
  return full_dim_tensor