DeepBeepMeep
commited on
Commit
·
25e8685
1
Parent(s):
84e409b
Optimized Vace RAM usage
Browse files- wan/modules/model.py +98 -56
- wan/text2video.py +13 -0
- wgp.py +1 -24
wan/modules/model.py
CHANGED
|
@@ -447,6 +447,21 @@ class WanAttentionBlock(nn.Module):
|
|
| 447 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 448 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 449 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
e = (self.modulation + e).chunk(6, dim=1)
|
| 451 |
|
| 452 |
# self-attention
|
|
@@ -485,13 +500,16 @@ class WanAttentionBlock(nn.Module):
|
|
| 485 |
|
| 486 |
x.addcmul_(y, e[5])
|
| 487 |
|
| 488 |
-
|
| 489 |
-
|
|
|
|
| 490 |
if context_scale == 1:
|
| 491 |
-
x.add_(
|
| 492 |
else:
|
| 493 |
-
x.add_(
|
| 494 |
-
return x
|
|
|
|
|
|
|
| 495 |
|
| 496 |
class VaceWanAttentionBlock(WanAttentionBlock):
|
| 497 |
def __init__(
|
|
@@ -516,18 +534,29 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
|
| 516 |
nn.init.zeros_(self.after_proj.weight)
|
| 517 |
nn.init.zeros_(self.after_proj.bias)
|
| 518 |
|
| 519 |
-
def forward(self,
|
| 520 |
# behold dbm magic !
|
|
|
|
|
|
|
| 521 |
if self.block_id == 0:
|
| 522 |
c = self.before_proj(c) + x
|
| 523 |
-
all_c = []
|
| 524 |
-
else:
|
| 525 |
-
all_c = c
|
| 526 |
-
c = all_c.pop(-1)
|
| 527 |
c = super().forward(c, **kwargs)
|
| 528 |
c_skip = self.after_proj(c)
|
| 529 |
-
|
| 530 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
class Head(nn.Module):
|
| 533 |
|
|
@@ -764,35 +793,37 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 764 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
| 765 |
return best_threshold
|
| 766 |
|
| 767 |
-
def forward_vace(
|
| 768 |
-
self,
|
| 769 |
-
x,
|
| 770 |
-
vace_context,
|
| 771 |
-
seq_len,
|
| 772 |
-
context,
|
| 773 |
-
e,
|
| 774 |
-
kwargs
|
| 775 |
-
):
|
| 776 |
-
# embeddings
|
| 777 |
-
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 778 |
-
c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 779 |
-
if (len(c) == 1 and seq_len == c[0].size(1)):
|
| 780 |
-
c = c[0]
|
| 781 |
-
else:
|
| 782 |
-
c = torch.cat([
|
| 783 |
-
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 784 |
-
dim=1) for u in c
|
| 785 |
-
])
|
| 786 |
-
|
| 787 |
-
# arguments
|
| 788 |
-
new_kwargs = dict(x=x)
|
| 789 |
-
new_kwargs.update(kwargs)
|
| 790 |
|
| 791 |
-
for block in self.vace_blocks:
|
| 792 |
-
c = block(c, context= context, e= e, **new_kwargs)
|
| 793 |
-
hints = c[:-1]
|
| 794 |
|
| 795 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
|
| 797 |
def forward(
|
| 798 |
self,
|
|
@@ -904,6 +935,34 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 904 |
x_list = [x]
|
| 905 |
context_list = [context]
|
| 906 |
del x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
should_calc = True
|
| 908 |
if self.enable_teacache:
|
| 909 |
if is_uncond:
|
|
@@ -935,23 +994,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 935 |
if joint_pass or not is_uncond:
|
| 936 |
self.previous_residual_cond = None
|
| 937 |
ori_hidden_states = x_list[0].clone()
|
| 938 |
-
# arguments
|
| 939 |
-
|
| 940 |
-
kwargs = dict(
|
| 941 |
-
seq_lens=seq_lens,
|
| 942 |
-
grid_sizes=grid_sizes,
|
| 943 |
-
freqs=freqs,
|
| 944 |
-
context_lens=context_lens)
|
| 945 |
-
|
| 946 |
-
if vace_context == None:
|
| 947 |
-
hints_list = [None ] *len(x_list)
|
| 948 |
-
else:
|
| 949 |
-
hints_list = []
|
| 950 |
-
for x, context in zip(x_list, context_list) :
|
| 951 |
-
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
|
| 952 |
-
del x, context
|
| 953 |
-
kwargs['context_scale'] = vace_context_scale
|
| 954 |
-
|
| 955 |
|
| 956 |
for block_idx, block in enumerate(self.blocks):
|
| 957 |
offload.shared_state["layer"] = block_idx
|
|
|
|
| 447 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 448 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 449 |
"""
|
| 450 |
+
hint = None
|
| 451 |
+
if self.block_id is not None and hints is not None:
|
| 452 |
+
kwargs = {
|
| 453 |
+
"seq_lens" : seq_lens,
|
| 454 |
+
"grid_sizes" : grid_sizes,
|
| 455 |
+
"freqs" :freqs,
|
| 456 |
+
"context" : context,
|
| 457 |
+
"context_lens" : context_lens,
|
| 458 |
+
"e" : e,
|
| 459 |
+
}
|
| 460 |
+
if self.block_id == 0:
|
| 461 |
+
hint = self.vace(hints, x, **kwargs)
|
| 462 |
+
else:
|
| 463 |
+
hint = self.vace(hints, None, **kwargs)
|
| 464 |
+
|
| 465 |
e = (self.modulation + e).chunk(6, dim=1)
|
| 466 |
|
| 467 |
# self-attention
|
|
|
|
| 500 |
|
| 501 |
x.addcmul_(y, e[5])
|
| 502 |
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if hint is not None:
|
| 506 |
if context_scale == 1:
|
| 507 |
+
x.add_(hint)
|
| 508 |
else:
|
| 509 |
+
x.add_(hint, alpha= context_scale)
|
| 510 |
+
return x
|
| 511 |
+
|
| 512 |
+
|
| 513 |
|
| 514 |
class VaceWanAttentionBlock(WanAttentionBlock):
|
| 515 |
def __init__(
|
|
|
|
| 534 |
nn.init.zeros_(self.after_proj.weight)
|
| 535 |
nn.init.zeros_(self.after_proj.bias)
|
| 536 |
|
| 537 |
+
def forward(self, hints, x, **kwargs):
|
| 538 |
# behold dbm magic !
|
| 539 |
+
c = hints[0]
|
| 540 |
+
hints[0] = None
|
| 541 |
if self.block_id == 0:
|
| 542 |
c = self.before_proj(c) + x
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
c = super().forward(c, **kwargs)
|
| 544 |
c_skip = self.after_proj(c)
|
| 545 |
+
hints[0] = c
|
| 546 |
+
return c_skip
|
| 547 |
+
|
| 548 |
+
# def forward(self, c, x, **kwargs):
|
| 549 |
+
# # behold dbm magic !
|
| 550 |
+
# if self.block_id == 0:
|
| 551 |
+
# c = self.before_proj(c) + x
|
| 552 |
+
# all_c = []
|
| 553 |
+
# else:
|
| 554 |
+
# all_c = c
|
| 555 |
+
# c = all_c.pop(-1)
|
| 556 |
+
# c = super().forward(c, **kwargs)
|
| 557 |
+
# c_skip = self.after_proj(c)
|
| 558 |
+
# all_c += [c_skip, c]
|
| 559 |
+
# return all_c
|
| 560 |
|
| 561 |
class Head(nn.Module):
|
| 562 |
|
|
|
|
| 793 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
| 794 |
return best_threshold
|
| 795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
|
|
|
|
|
|
|
|
|
|
| 797 |
|
| 798 |
+
# def forward_vace(
|
| 799 |
+
# self,
|
| 800 |
+
# x,
|
| 801 |
+
# vace_context,
|
| 802 |
+
# seq_len,
|
| 803 |
+
# context,
|
| 804 |
+
# e,
|
| 805 |
+
# kwargs
|
| 806 |
+
# ):
|
| 807 |
+
# # embeddings
|
| 808 |
+
# c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 809 |
+
# c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 810 |
+
# if (len(c) == 1 and seq_len == c[0].size(1)):
|
| 811 |
+
# c = c[0]
|
| 812 |
+
# else:
|
| 813 |
+
# c = torch.cat([
|
| 814 |
+
# torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 815 |
+
# dim=1) for u in c
|
| 816 |
+
# ])
|
| 817 |
+
|
| 818 |
+
# # arguments
|
| 819 |
+
# new_kwargs = dict(x=x)
|
| 820 |
+
# new_kwargs.update(kwargs)
|
| 821 |
+
|
| 822 |
+
# for block in self.vace_blocks:
|
| 823 |
+
# c = block(c, context= context, e= e, **new_kwargs)
|
| 824 |
+
# hints = c[:-1]
|
| 825 |
+
|
| 826 |
+
# return hints
|
| 827 |
|
| 828 |
def forward(
|
| 829 |
self,
|
|
|
|
| 935 |
x_list = [x]
|
| 936 |
context_list = [context]
|
| 937 |
del x
|
| 938 |
+
|
| 939 |
+
# arguments
|
| 940 |
+
|
| 941 |
+
kwargs = dict(
|
| 942 |
+
seq_lens=seq_lens,
|
| 943 |
+
grid_sizes=grid_sizes,
|
| 944 |
+
freqs=freqs,
|
| 945 |
+
context_lens=context_lens,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
if vace_context == None:
|
| 949 |
+
hints_list = [None ] *len(x_list)
|
| 950 |
+
else:
|
| 951 |
+
# embeddings
|
| 952 |
+
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
| 953 |
+
c = [u.flatten(2).transpose(1, 2) for u in c]
|
| 954 |
+
if (len(c) == 1 and seq_len == c[0].size(1)):
|
| 955 |
+
c = c[0]
|
| 956 |
+
else:
|
| 957 |
+
c = torch.cat([
|
| 958 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 959 |
+
dim=1) for u in c
|
| 960 |
+
])
|
| 961 |
+
|
| 962 |
+
kwargs['context_scale'] = vace_context_scale
|
| 963 |
+
hints_list = [ [c] if i==0 else [c.clone()] for i in range(len(x_list)) ]
|
| 964 |
+
del c
|
| 965 |
+
|
| 966 |
should_calc = True
|
| 967 |
if self.enable_teacache:
|
| 968 |
if is_uncond:
|
|
|
|
| 994 |
if joint_pass or not is_uncond:
|
| 995 |
self.previous_residual_cond = None
|
| 996 |
ori_hidden_states = x_list[0].clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 997 |
|
| 998 |
for block_idx, block in enumerate(self.blocks):
|
| 999 |
offload.shared_state["layer"] = block_idx
|
wan/text2video.py
CHANGED
|
@@ -143,6 +143,8 @@ class WanT2V:
|
|
| 143 |
seq_len=32760,
|
| 144 |
keep_last=True)
|
| 145 |
|
|
|
|
|
|
|
| 146 |
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
| 147 |
if ref_images is None:
|
| 148 |
ref_images = [None] * len(frames)
|
|
@@ -505,3 +507,14 @@ class WanT2V:
|
|
| 505 |
dist.barrier()
|
| 506 |
|
| 507 |
return videos[0] if self.rank == 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
seq_len=32760,
|
| 144 |
keep_last=True)
|
| 145 |
|
| 146 |
+
self.adapt_vace_model()
|
| 147 |
+
|
| 148 |
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
| 149 |
if ref_images is None:
|
| 150 |
ref_images = [None] * len(frames)
|
|
|
|
| 507 |
dist.barrier()
|
| 508 |
|
| 509 |
return videos[0] if self.rank == 0 else None
|
| 510 |
+
|
| 511 |
+
def adapt_vace_model(self):
|
| 512 |
+
model = self.model
|
| 513 |
+
modules_dict= { k: m for k, m in model.named_modules()}
|
| 514 |
+
for num in range(15):
|
| 515 |
+
module = modules_dict[f"vace_blocks.{num}"]
|
| 516 |
+
target = modules_dict[f"blocks.{2*num}"]
|
| 517 |
+
setattr(target, "vace", module )
|
| 518 |
+
delattr(model, "vace_blocks")
|
| 519 |
+
|
| 520 |
+
|
wgp.py
CHANGED
|
@@ -910,14 +910,6 @@ def get_queue_table(queue):
|
|
| 910 |
if len(queue) == 1:
|
| 911 |
return data
|
| 912 |
|
| 913 |
-
# def td(l, content, width =None):
|
| 914 |
-
# if width !=None:
|
| 915 |
-
# l.append("<TD WIDTH="+ str(width) + "px>" + content + "</TD>")
|
| 916 |
-
# else:
|
| 917 |
-
# l.append("<TD>" + content + "</TD>")
|
| 918 |
-
|
| 919 |
-
# data.append("<STYLE> .TB, .TB th, .TB td {border: 1px solid #CCCCCC};></STYLE><TABLE CLASS=TB><TR BGCOLOR=#F2F2F2><TD Style='Bold'>Qty</TD><TD>Prompt</TD><TD>Steps</TD><TD></TD><TD><TD></TD><TD></TD><TD></TD></TR>")
|
| 920 |
-
|
| 921 |
for i, item in enumerate(queue):
|
| 922 |
if i==0:
|
| 923 |
continue
|
|
@@ -937,22 +929,7 @@ def get_queue_table(queue):
|
|
| 937 |
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 938 |
if end_img_uri:
|
| 939 |
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 940 |
-
|
| 941 |
-
# data.append("<TR>")
|
| 942 |
-
# else:
|
| 943 |
-
# data.append("<TR BGCOLOR=#F2F2F2>")
|
| 944 |
-
|
| 945 |
-
# td(data,str(item.get('repeats', "1")) )
|
| 946 |
-
# td(data, prompt_cell, "100%")
|
| 947 |
-
# td(data, num_steps, "100%")
|
| 948 |
-
# td(data, start_img_md)
|
| 949 |
-
# td(data, end_img_md)
|
| 950 |
-
# td(data, "↑")
|
| 951 |
-
# td(data, "↓")
|
| 952 |
-
# td(data, "✖")
|
| 953 |
-
# data.append("</TR>")
|
| 954 |
-
# data.append("</TABLE>")
|
| 955 |
-
# return ''.join(data)
|
| 956 |
|
| 957 |
data.append([item.get('repeats', "1"),
|
| 958 |
prompt_cell,
|
|
|
|
| 910 |
if len(queue) == 1:
|
| 911 |
return data
|
| 912 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 913 |
for i, item in enumerate(queue):
|
| 914 |
if i==0:
|
| 915 |
continue
|
|
|
|
| 929 |
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 930 |
if end_img_uri:
|
| 931 |
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 932 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
|
| 934 |
data.append([item.get('repeats', "1"),
|
| 935 |
prompt_cell,
|