DeepBeepMeep commited on
Commit
25e8685
·
1 Parent(s): 84e409b

Optimized Vace RAM usage

Browse files
Files changed (3) hide show
  1. wan/modules/model.py +98 -56
  2. wan/text2video.py +13 -0
  3. wgp.py +1 -24
wan/modules/model.py CHANGED
@@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
447
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
448
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
449
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  e = (self.modulation + e).chunk(6, dim=1)
451
 
452
  # self-attention
@@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
485
 
486
  x.addcmul_(y, e[5])
487
 
488
-
489
- if self.block_id is not None and hints != None:
 
490
  if context_scale == 1:
491
- x.add_(hints[self.block_id])
492
  else:
493
- x.add_(hints[self.block_id], alpha =context_scale)
494
- return x
 
 
495
 
496
  class VaceWanAttentionBlock(WanAttentionBlock):
497
  def __init__(
@@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
516
  nn.init.zeros_(self.after_proj.weight)
517
  nn.init.zeros_(self.after_proj.bias)
518
 
519
- def forward(self, c, x, **kwargs):
520
  # behold dbm magic !
 
 
521
  if self.block_id == 0:
522
  c = self.before_proj(c) + x
523
- all_c = []
524
- else:
525
- all_c = c
526
- c = all_c.pop(-1)
527
  c = super().forward(c, **kwargs)
528
  c_skip = self.after_proj(c)
529
- all_c += [c_skip, c]
530
- return all_c
 
 
 
 
 
 
 
 
 
 
 
 
 
531
 
532
  class Head(nn.Module):
533
 
@@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
764
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
765
  return best_threshold
766
 
767
- def forward_vace(
768
- self,
769
- x,
770
- vace_context,
771
- seq_len,
772
- context,
773
- e,
774
- kwargs
775
- ):
776
- # embeddings
777
- c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
778
- c = [u.flatten(2).transpose(1, 2) for u in c]
779
- if (len(c) == 1 and seq_len == c[0].size(1)):
780
- c = c[0]
781
- else:
782
- c = torch.cat([
783
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
784
- dim=1) for u in c
785
- ])
786
-
787
- # arguments
788
- new_kwargs = dict(x=x)
789
- new_kwargs.update(kwargs)
790
 
791
- for block in self.vace_blocks:
792
- c = block(c, context= context, e= e, **new_kwargs)
793
- hints = c[:-1]
794
 
795
- return hints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
 
797
  def forward(
798
  self,
@@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
904
  x_list = [x]
905
  context_list = [context]
906
  del x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907
  should_calc = True
908
  if self.enable_teacache:
909
  if is_uncond:
@@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
935
  if joint_pass or not is_uncond:
936
  self.previous_residual_cond = None
937
  ori_hidden_states = x_list[0].clone()
938
- # arguments
939
-
940
- kwargs = dict(
941
- seq_lens=seq_lens,
942
- grid_sizes=grid_sizes,
943
- freqs=freqs,
944
- context_lens=context_lens)
945
-
946
- if vace_context == None:
947
- hints_list = [None ] *len(x_list)
948
- else:
949
- hints_list = []
950
- for x, context in zip(x_list, context_list) :
951
- hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
952
- del x, context
953
- kwargs['context_scale'] = vace_context_scale
954
-
955
 
956
  for block_idx, block in enumerate(self.blocks):
957
  offload.shared_state["layer"] = block_idx
 
447
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
448
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
449
  """
450
+ hint = None
451
+ if self.block_id is not None and hints is not None:
452
+ kwargs = {
453
+ "seq_lens" : seq_lens,
454
+ "grid_sizes" : grid_sizes,
455
+ "freqs" :freqs,
456
+ "context" : context,
457
+ "context_lens" : context_lens,
458
+ "e" : e,
459
+ }
460
+ if self.block_id == 0:
461
+ hint = self.vace(hints, x, **kwargs)
462
+ else:
463
+ hint = self.vace(hints, None, **kwargs)
464
+
465
  e = (self.modulation + e).chunk(6, dim=1)
466
 
467
  # self-attention
 
500
 
501
  x.addcmul_(y, e[5])
502
 
503
+
504
+
505
+ if hint is not None:
506
  if context_scale == 1:
507
+ x.add_(hint)
508
  else:
509
+ x.add_(hint, alpha= context_scale)
510
+ return x
511
+
512
+
513
 
514
  class VaceWanAttentionBlock(WanAttentionBlock):
515
  def __init__(
 
534
  nn.init.zeros_(self.after_proj.weight)
535
  nn.init.zeros_(self.after_proj.bias)
536
 
537
+ def forward(self, hints, x, **kwargs):
538
  # behold dbm magic !
539
+ c = hints[0]
540
+ hints[0] = None
541
  if self.block_id == 0:
542
  c = self.before_proj(c) + x
 
 
 
 
543
  c = super().forward(c, **kwargs)
544
  c_skip = self.after_proj(c)
545
+ hints[0] = c
546
+ return c_skip
547
+
548
+ # def forward(self, c, x, **kwargs):
549
+ # # behold dbm magic !
550
+ # if self.block_id == 0:
551
+ # c = self.before_proj(c) + x
552
+ # all_c = []
553
+ # else:
554
+ # all_c = c
555
+ # c = all_c.pop(-1)
556
+ # c = super().forward(c, **kwargs)
557
+ # c_skip = self.after_proj(c)
558
+ # all_c += [c_skip, c]
559
+ # return all_c
560
 
561
  class Head(nn.Module):
562
 
 
793
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
794
  return best_threshold
795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
 
 
 
 
797
 
798
+ # def forward_vace(
799
+ # self,
800
+ # x,
801
+ # vace_context,
802
+ # seq_len,
803
+ # context,
804
+ # e,
805
+ # kwargs
806
+ # ):
807
+ # # embeddings
808
+ # c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
809
+ # c = [u.flatten(2).transpose(1, 2) for u in c]
810
+ # if (len(c) == 1 and seq_len == c[0].size(1)):
811
+ # c = c[0]
812
+ # else:
813
+ # c = torch.cat([
814
+ # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
815
+ # dim=1) for u in c
816
+ # ])
817
+
818
+ # # arguments
819
+ # new_kwargs = dict(x=x)
820
+ # new_kwargs.update(kwargs)
821
+
822
+ # for block in self.vace_blocks:
823
+ # c = block(c, context= context, e= e, **new_kwargs)
824
+ # hints = c[:-1]
825
+
826
+ # return hints
827
 
828
  def forward(
829
  self,
 
935
  x_list = [x]
936
  context_list = [context]
937
  del x
938
+
939
+ # arguments
940
+
941
+ kwargs = dict(
942
+ seq_lens=seq_lens,
943
+ grid_sizes=grid_sizes,
944
+ freqs=freqs,
945
+ context_lens=context_lens,
946
+ )
947
+
948
+ if vace_context == None:
949
+ hints_list = [None ] *len(x_list)
950
+ else:
951
+ # embeddings
952
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
953
+ c = [u.flatten(2).transpose(1, 2) for u in c]
954
+ if (len(c) == 1 and seq_len == c[0].size(1)):
955
+ c = c[0]
956
+ else:
957
+ c = torch.cat([
958
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
959
+ dim=1) for u in c
960
+ ])
961
+
962
+ kwargs['context_scale'] = vace_context_scale
963
+ hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
964
+ del c
965
+
966
  should_calc = True
967
  if self.enable_teacache:
968
  if is_uncond:
 
994
  if joint_pass or not is_uncond:
995
  self.previous_residual_cond = None
996
  ori_hidden_states = x_list[0].clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
997
 
998
  for block_idx, block in enumerate(self.blocks):
999
  offload.shared_state["layer"] = block_idx
wan/text2video.py CHANGED
@@ -143,6 +143,8 @@ class WanT2V:
143
  seq_len=32760,
144
  keep_last=True)
145
 
 
 
146
  def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
147
  if ref_images is None:
148
  ref_images = [None] * len(frames)
@@ -505,3 +507,14 @@ class WanT2V:
505
  dist.barrier()
506
 
507
  return videos[0] if self.rank == 0 else None
 
 
 
 
 
 
 
 
 
 
 
 
143
  seq_len=32760,
144
  keep_last=True)
145
 
146
+ self.adapt_vace_model()
147
+
148
  def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
149
  if ref_images is None:
150
  ref_images = [None] * len(frames)
 
507
  dist.barrier()
508
 
509
  return videos[0] if self.rank == 0 else None
510
+
511
+ def adapt_vace_model(self):
512
+ model = self.model
513
+ modules_dict= { k: m for k, m in model.named_modules()}
514
+ for num in range(15):
515
+ module = modules_dict[f"vace_blocks.{num}"]
516
+ target = modules_dict[f"blocks.{2*num}"]
517
+ setattr(target, "vace", module )
518
+ delattr(model, "vace_blocks")
519
+
520
+
wgp.py CHANGED
@@ -910,14 +910,6 @@ def get_queue_table(queue):
910
  if len(queue) == 1:
911
  return data
912
 
913
- # def td(l, content, width =None):
914
- # if width !=None:
915
- # l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
916
- # else:
917
- # l.append("<TD>" + content + "</TD>")
918
-
919
- # data.append("<STYLE> .TB, .TB th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
920
-
921
  for i, item in enumerate(queue):
922
  if i==0:
923
  continue
@@ -937,22 +929,7 @@ def get_queue_table(queue):
937
  start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
938
  if end_img_uri:
939
  end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
940
- # if i % 2 == 1:
941
- # data.append("<TR>")
942
- # else:
943
- # data.append("<TR BGCOLOR=#F2F2F2>")
944
-
945
- # td(data,str(item.get('repeats', "1")) )
946
- # td(data, prompt_cell, "100%")
947
- # td(data, num_steps, "100%")
948
- # td(data, start_img_md)
949
- # td(data, end_img_md)
950
- # td(data, "↑")
951
- # td(data, "↓")
952
- # td(data, "✖")
953
- # data.append("</TR>")
954
- # data.append("</TABLE>")
955
- # return ''.join(data)
956
 
957
  data.append([item.get('repeats', "1"),
958
  prompt_cell,
 
910
  if len(queue) == 1:
911
  return data
912
 
 
 
 
 
 
 
 
 
913
  for i, item in enumerate(queue):
914
  if i==0:
915
  continue
 
929
  start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
930
  if end_img_uri:
931
  end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
932
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
 
934
  data.append([item.get('repeats', "1"),
935
  prompt_cell,