rinong commited on
Commit
4663a72
1 Parent(s): f2ea589

Modified s_dict generation

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. model/sg2_model.py +30 -1
  3. styleclip/styleclip_global.py +5 -2
app.py CHANGED
@@ -368,8 +368,9 @@ with blocks:
368
  vid_button = gr.Button("Generate Video")
369
  loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
370
  with gr.Row():
371
- gr.Markdown("Warning: Videos generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
372
- gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we reccomend cloning the gradio app, adjusting `num_alphas` in `generate_videos`, and running the code locally.")
 
373
  with gr.Column():
374
  vid_output = gr.outputs.Video(label="Output Video")
375
 
 
368
  vid_button = gr.Button("Generate Video")
369
  loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
370
  with gr.Row():
371
+ with gr.Column():
372
+ gr.Markdown("Warning: Videos generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
373
+ gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we reccomend cloning the gradio app, adjusting `num_alphas` in `generate_videos`, and running the code locally.")
374
  with gr.Column():
375
  vid_output = gr.outputs.Video(label="Output Video")
376
 
model/sg2_model.py CHANGED
@@ -526,7 +526,36 @@ class Generator(nn.Module):
526
  if not input_is_latent:
527
  styles = [self.style(s) for s in styles]
528
 
529
- s_codes = [{layer: layer(s) for layer in self.modulation_layers} for s in styles] * len(styles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
  return s_codes
532
 
 
526
  if not input_is_latent:
527
  styles = [self.style(s) for s in styles]
528
 
529
+ s_codes = {# const block
530
+ self.modulation_layers[0]: self.modulation_layers[0](styles[0]),
531
+ self.modulation_layers[1]: self.modulation_layers[1](styles[1]),
532
+ # conv layers
533
+ self.modulation_layers[2]: self.modulation_layers[2](styles[2]),
534
+ self.modulation_layers[3]: self.modulation_layers[3](styles[3]),
535
+ self.modulation_layers[5]: self.modulation_layers[5](styles[4]),
536
+ self.modulation_layers[6]: self.modulation_layers[6](styles[5]),
537
+ self.modulation_layers[8]: self.modulation_layers[8](styles[6]),
538
+ self.modulation_layers[9]: self.modulation_layers[9](styles[7]),
539
+ self.modulation_layers[11]: self.modulation_layers[11](styles[8]),
540
+ self.modulation_layers[12]: self.modulation_layers[12](styles[9]),
541
+ self.modulation_layers[14]: self.modulation_layers[14](styles[10]),
542
+ self.modulation_layers[15]: self.modulation_layers[15](styles[11]),
543
+ self.modulation_layers[17]: self.modulation_layers[17](styles[12]),
544
+ self.modulation_layers[18]: self.modulation_layers[18](styles[13]),
545
+ self.modulation_layers[20]: self.modulation_layers[20](styles[14]),
546
+ self.modulation_layers[21]: self.modulation_layers[21](styles[15]),
547
+ self.modulation_layers[23]: self.modulation_layers[23](styles[16]),
548
+ self.modulation_layers[24]: self.modulation_layers[24](styles[17]),
549
+ # toRGB layers
550
+ self.modulation_layers[4]: self.modulation_layers[4](styles[3]),
551
+ self.modulation_layers[7]: self.modulation_layers[7](styles[5]),
552
+ self.modulation_layers[10]: self.modulation_layers[10](styles[7]),
553
+ self.modulation_layers[13]: self.modulation_layers[13](styles[9]),
554
+ self.modulation_layers[16]: self.modulation_layers[16](styles[11]),
555
+ self.modulation_layers[19]: self.modulation_layers[19](styles[13]),
556
+ self.modulation_layers[22]: self.modulation_layers[22](styles[15]),
557
+ self.modulation_layers[25]: self.modulation_layers[25](styles[17]),
558
+ }
559
 
560
  return s_codes
561
 
styleclip/styleclip_global.py CHANGED
@@ -120,7 +120,10 @@ def get_direction(neutral_class, target_class, beta, di, clip_model=None):
120
 
121
  dt = class_weights[:, 1] - class_weights[:, 0]
122
  dt = dt / dt.norm()
123
- dt = dt.type(type(di))
 
 
 
124
  relevance = di @ dt
125
  mask = relevance.abs() > beta
126
  direction = relevance * mask
@@ -144,7 +147,7 @@ def style_tensor_to_style_dict(style_tensor, refernce_generator):
144
  def style_dict_to_style_tensor(style_dict, reference_generator):
145
  style_layers = reference_generator.modulation_layers
146
 
147
- style_tensor = torch.zeros(shape=(1, 9088))
148
  for layer in style_dict:
149
  layer_idx = style_layers.index(layer)
150
  style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
 
120
 
121
  dt = class_weights[:, 1] - class_weights[:, 0]
122
  dt = dt / dt.norm()
123
+
124
+ dt = dt.float()
125
+ di = di.float()
126
+
127
  relevance = di @ dt
128
  mask = relevance.abs() > beta
129
  direction = relevance * mask
 
147
  def style_dict_to_style_tensor(style_dict, reference_generator):
148
  style_layers = reference_generator.modulation_layers
149
 
150
+ style_tensor = torch.zeros(size=(1, 9088))
151
  for layer in style_dict:
152
  layer_idx = style_layers.index(layer)
153
  style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]