import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional
from einops import rearrange
from .utils import hash_state_dict_keys
from .wan_video_camera_controller import SimpleAdapter

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
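# flash_attention() dispatches to the fastest available backend, in priority order:
# FlashAttention-3 -> FlashAttention-2 -> SageAttention -> PyTorch SDPA.
# compatibility_mode=True forces the plain F.scaled_dot_product_attention path.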
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    compatibility_mode=False,
):
    if compatibility_mode:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
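# AdaLN-style modulation used by DiTBlock and Head: applies a shift/scale derived
# from the timestep embedding to the normalized hidden states.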
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    return x * (1 + scale) + shift


def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(
        position.type(torch.float64),
        torch.pow(
            10000,
            -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(
                dim // 2
            ),
        ),
    )
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)
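# 3D RoPE precompute: the per-head channel dimension is split across the frame,
# height and width axes (frame gets dim - 2 * (dim // 3) channels, height and width
# get dim // 3 each). The frame table is one entry longer because positions start
# at -1, leaving a slot before frame 0 for conditioning tokens.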
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    # 3d rope precompute
    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end + 1, theta)
    h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    return f_freqs_cis, h_freqs_cis, w_freqs_cis


def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    # 1d rope precompute
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
    # add f = -1: positions start at -1 so that index 0 corresponds to position -1
    positions = torch.arange(-1, end, device=freqs.device)
    freqs = torch.outer(positions, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
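# Applies the precomputed complex rotary frequencies to a (b, s, n*d) tensor by
# viewing each head's channels as complex pairs, rotating them, and flattening back.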
def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    x_out = torch.view_as_complex(
        x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
    )
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
    return x_out.to(x.dtype)
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        dtype = x.dtype
        return self.norm(x.float()).to(dtype) * self.weight


class AttentionModule(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
        return x
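# Low-rank adapter used by SelfAttention.init_lora() for the IP (ip_image) branch.
# The up projection is zero-initialized, so a freshly added LoRA is a no-op until trained.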
class LoRALinearLayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 128,
        device="cuda",
        dtype: Optional[torch.dtype] = torch.float32,
    ):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype
        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)
        return up_hidden_states.to(orig_dtype)
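# Self-attention with an optional IP branch: when cond_size is set, the trailing
# cond_size tokens go through LoRA-augmented q/k/v projections, their k/v are cached
# in kv_cache, and the main tokens attend over [main tokens + IP tokens].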
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)
        self.kv_cache = None
        self.cond_size = None

    def init_lora(self, train=False):
        dim = self.dim
        self.q_loras = LoRALinearLayer(dim, dim, rank=128)
        self.k_loras = LoRALinearLayer(dim, dim, rank=128)
        self.v_loras = LoRALinearLayer(dim, dim, rank=128)
        requires_grad = train
        for lora in [self.q_loras, self.k_loras, self.v_loras]:
            for param in lora.parameters():
                param.requires_grad = requires_grad

    def forward(self, x, freqs):
        if self.cond_size is not None:
            if self.kv_cache is None:
                x_main, x_ip = x[:, : -self.cond_size], x[:, -self.cond_size :]
                split_point = freqs.shape[0] - self.cond_size
                freqs_main = freqs[:split_point]
                freqs_ip = freqs[split_point:]
                q_main = self.norm_q(self.q(x_main))
                k_main = self.norm_k(self.k(x_main))
                v_main = self.v(x_main)
                q_main = rope_apply(q_main, freqs_main, self.num_heads)
                k_main = rope_apply(k_main, freqs_main, self.num_heads)
                q_ip = self.norm_q(self.q(x_ip) + self.q_loras(x_ip))
                k_ip = self.norm_k(self.k(x_ip) + self.k_loras(x_ip))
                v_ip = self.v(x_ip) + self.v_loras(x_ip)
                q_ip = rope_apply(q_ip, freqs_ip, self.num_heads)
                k_ip = rope_apply(k_ip, freqs_ip, self.num_heads)
                self.kv_cache = {"k_ip": k_ip.detach(), "v_ip": v_ip.detach()}
                full_k = torch.concat([k_main, k_ip], dim=1)
                full_v = torch.concat([v_main, v_ip], dim=1)
                cond_out = self.attn(q_ip, k_ip, v_ip)
                main_out = self.attn(q_main, full_k, full_v)
                out = torch.concat([main_out, cond_out], dim=1)
                return self.o(out)
            else:
                k_ip = self.kv_cache["k_ip"]
                v_ip = self.kv_cache["v_ip"]
                q_main = self.norm_q(self.q(x))
                k_main = self.norm_k(self.k(x))
                v_main = self.v(x)
                q_main = rope_apply(q_main, freqs, self.num_heads)
                k_main = rope_apply(k_main, freqs, self.num_heads)
                full_k = torch.concat([k_main, k_ip], dim=1)
                full_v = torch.concat([v_main, v_ip], dim=1)
                x = self.attn(q_main, full_k, full_v)
                return self.o(x)
        else:
            q = self.norm_q(self.q(x))
            k = self.norm_k(self.k(x))
            v = self.v(x)
            q = rope_apply(q, freqs, self.num_heads)
            k = rope_apply(k, freqs, self.num_heads)
            x = self.attn(q, k, v)
            return self.o(x)
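# Cross-attention against the text context. When has_image_input is set, the first
# 257 tokens of y are image tokens with their own k/v projections, and the
# image-attention output is added to the text-attention output.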
class CrossAttention(nn.Module):
    def __init__(
        self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self.has_image_input:
            img = y[:, :257]
            ctx = y[:, 257:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)
        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            x = x + y
        return self.o(x)


class GateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        return x + gate * residual
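# One transformer block: AdaLN-modulated self-attention, cross-attention to the text
# context, and an AdaLN-modulated FFN. When x_ip (IP tokens) is provided, the IP
# tokens are concatenated for self-attention and modulated with t_mod_ip.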
class DiTBlock(nn.Module):
    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input
        )
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, x_ip=None, t_mod_ip=None):
        # msa: multi-head self-attention, mlp: multi-layer perceptron
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=1)
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        if x_ip is not None:
            (
                shift_msa_ip,
                scale_msa_ip,
                gate_msa_ip,
                shift_mlp_ip,
                scale_mlp_ip,
                gate_mlp_ip,
            ) = (
                self.modulation.to(dtype=t_mod_ip.dtype, device=t_mod_ip.device)
                + t_mod_ip
            ).chunk(6, dim=1)
            input_x_ip = modulate(
                self.norm1(x_ip), shift_msa_ip, scale_msa_ip
            )  # [1, 1024, 5120]
            self.self_attn.cond_size = input_x_ip.shape[1]
            input_x = torch.concat([input_x, input_x_ip], dim=1)
            self.self_attn.kv_cache = None
        attn_out = self.self_attn(input_x, freqs)
        if x_ip is not None:
            attn_out, attn_out_ip = (
                attn_out[:, : -self.self_attn.cond_size],
                attn_out[:, -self.self_attn.cond_size :],
            )
        x = self.gate(x, gate_msa, attn_out)
        x = x + self.cross_attn(self.norm3(x), context)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        if x_ip is not None:
            x_ip = self.gate(x_ip, gate_msa_ip, attn_out_ip)
            input_x_ip = modulate(self.norm2(x_ip), shift_mlp_ip, scale_mlp_ip)
            x_ip = self.gate(x_ip, gate_mlp_ip, self.ffn(input_x_ip))
        return x, x_ip
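# Projects 1280-d CLIP image features into the model dimension for cross-attention;
# has_pos_emb adds a learned positional embedding for up to 514 image tokens.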
class MLP(torch.nn.Module):
    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        self.proj = torch.nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        )
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if self.has_pos_emb:
            x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x)
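# Output head: a final AdaLN modulation followed by a linear projection to
# out_dim * prod(patch_size) channels per token, later folded back by unpatchify().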
class Head(nn.Module):
    def __init__(
        self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
    ):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        if len(t_mod.shape) == 3:
            shift, scale = (
                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device)
                + t_mod.unsqueeze(2)
            ).chunk(2, dim=2)
            x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
        else:
            shift, scale = (
                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
            ).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
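# Wan video DiT backbone: patchifies the latent video with a Conv3d, embeds the text
# and timestep conditions, runs the DiT blocks with 3D RoPE, and unpatchifies the
# output. Passing ip_image enables the IP path, which uses timestep 0 and its own
# RoPE positions.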
class WanModel(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size
        )
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList(
            [
                DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
                for _ in range(num_layers)
            ]
        )
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        if has_image_input:
            self.img_emb = MLP(
                1280, dim, has_pos_emb=has_image_pos_emb
            )  # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(
                in_dim_control_adapter,
                dim,
                kernel_size=patch_size[1:],
                stride=patch_size[1:],
            )
        else:
            self.control_adapter = None

    def patchify(
        self, x: torch.Tensor, control_camera_latents_input: torch.Tensor = None
    ):
        x = self.patch_embedding(x)
        if (
            self.control_adapter is not None
            and control_camera_latents_input is not None
        ):
            y_camera = self.control_adapter(control_camera_latents_input)
            x = [u + v for u, v in zip(x, y_camera)]
            x = x[0].unsqueeze(0)
        grid_size = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        return rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )
    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        clip_feature: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        ip_image=None,
        **kwargs,
    ):
        x_ip = None
        t_mod_ip = None
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        if ip_image is not None:
            # The IP branch is conditioned on timestep 0.
            timestep_ip = torch.zeros_like(timestep)  # [B] with 0s
            t_ip = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, timestep_ip)
            )
            t_mod_ip = self.time_projection(t_ip).unflatten(1, (6, self.dim))
        x, (f, h, w) = self.patchify(x)
        offset = 1
        freqs = (
            torch.cat(
                [
                    self.freqs[0][offset : f + offset]
                    .view(f, 1, 1, -1)
                    .expand(f, h, w, -1),
                    self.freqs[1][offset : h + offset]
                    .view(1, h, 1, -1)
                    .expand(f, h, w, -1),
                    self.freqs[2][offset : w + offset]
                    .view(1, 1, w, -1)
                    .expand(f, h, w, -1),
                ],
                dim=-1,
            )
            .reshape(f * h * w, 1, -1)
            .to(x.device)
        )
        if ip_image is not None:
            if ip_image.dim() == 6 and ip_image.shape[3] == 1:
                ip_image = ip_image.squeeze(1)
            x_ip, (f_ip, h_ip, w_ip) = self.patchify(
                ip_image
            )  # x_ip [1, 1024, 5120] [B, N, D], f_ip = 1, h_ip = 32, w_ip = 32
            # IP tokens use frame index 0 (position -1 in precompute_freqs_cis) and
            # height/width positions placed after the main grid, so their RoPE
            # coordinates never collide with the video tokens.
            freqs_ip = (
                torch.cat(
                    [
                        self.freqs[0][0]
                        .view(f_ip, 1, 1, -1)
                        .expand(f_ip, h_ip, w_ip, -1),
                        self.freqs[1][h + offset : h + offset + h_ip]
                        .view(1, h_ip, 1, -1)
                        .expand(f_ip, h_ip, w_ip, -1),
                        self.freqs[2][w + offset : w + offset + w_ip]
                        .view(1, 1, w_ip, -1)
                        .expand(f_ip, h_ip, w_ip, -1),
                    ],
                    dim=-1,
                )
                .reshape(f_ip * h_ip * w_ip, 1, -1)
                .to(x_ip.device)
            )
            freqs = torch.cat([freqs, freqs_ip], dim=0)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x, x_ip = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            context,
                            t_mod,
                            freqs,
                            x_ip,
                            t_mod_ip,
                            use_reentrant=False,
                        )
                else:
                    x, x_ip = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        context,
                        t_mod,
                        freqs,
                        x_ip,
                        t_mod_ip,
                        use_reentrant=False,
                    )
            else:
                x, x_ip = block(x, context, t_mod, freqs, x_ip, t_mod_ip)
        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x
    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()
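# Maps external checkpoints onto this module's parameter names. from_diffusers()
# renames Diffusers-style keys; from_civitai() keeps the keys and infers a model
# config from a hash of the state-dict keys (via hash_state_dict_keys).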
class WanModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
            else:
                # Replace the block index with "0" to look the key up in rename_dict,
                # then restore the original block index in the renamed key.
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(
                        name_.split(".")[:1]
                        + [name.split(".")[1]]
                        + name_.split(".")[2:]
                    )
                    state_dict_[name_] = param
        if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config
    def from_civitai(self, state_dict):
        state_dict = {
            name: param
            for name, param in state_dict.items()
            if not name.startswith("vace")
        }
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
            }
        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
            }
        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
            }
        elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
            }
| elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": | |
| # 1.3B PAI control | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": | |
| # 14B PAI control | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| "has_image_pos_emb": True, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504": | |
| # 1.3B PAI control v1.1 | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| "has_ref_conv": True, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b": | |
| # 14B PAI control v1.1 | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| "has_ref_conv": True, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901": | |
| # 1.3B PAI control-camera v1.1 | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 32, | |
| "dim": 1536, | |
| "ffn_dim": 8960, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 12, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| "has_ref_conv": False, | |
| "add_control_adapter": True, | |
| "in_dim_control_adapter": 24, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae": | |
| # 14B PAI control-camera v1.1 | |
| config = { | |
| "has_image_input": True, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 32, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| "has_ref_conv": False, | |
| "add_control_adapter": True, | |
| "in_dim_control_adapter": 24, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": | |
| # Wan-AI/Wan2.2-TI2V-5B | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 48, | |
| "dim": 3072, | |
| "ffn_dim": 14336, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 48, | |
| "num_heads": 24, | |
| "num_layers": 30, | |
| "eps": 1e-6, | |
| "seperated_timestep": True, | |
| "require_clip_embedding": False, | |
| "require_vae_embedding": False, | |
| "fuse_vae_embedding_in_latents": True, | |
| } | |
| elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": | |
| # Wan-AI/Wan2.2-I2V-A14B | |
| config = { | |
| "has_image_input": False, | |
| "patch_size": [1, 2, 2], | |
| "in_dim": 36, | |
| "dim": 5120, | |
| "ffn_dim": 13824, | |
| "freq_dim": 256, | |
| "text_dim": 4096, | |
| "out_dim": 16, | |
| "num_heads": 40, | |
| "num_layers": 40, | |
| "eps": 1e-6, | |
| "require_clip_embedding": False, | |
| } | |
| else: | |
| config = {} | |
| return state_dict, config | |