# ViTeX-Edit-14B: diffsynth/models/wan_video_vace.py
import torch
import torch.nn as nn
from .wan_video_dit import DiTBlock, CrossAttention, flash_attention
from ..core.gradient import gradient_checkpoint_forward
class _OffloadToCPU(torch.autograd.Function):
"""Move tensor to CPU in forward, move gradient to GPU in backward."""
@staticmethod
def forward(ctx, tensor):
ctx.gpu_device = tensor.device
return tensor.detach().cpu().requires_grad_(True)
@staticmethod
def backward(ctx, grad_output):
return grad_output.to(ctx.gpu_device)
class _RestoreToGPU(torch.autograd.Function):
"""Move CPU tensor back to GPU in forward, move gradient to CPU in backward."""
@staticmethod
def forward(ctx, tensor, device):
ctx.save_for_backward(torch.empty(0))
ctx._device_str = str(device)
return tensor.to(device)
@staticmethod
def backward(ctx, grad_output):
return grad_output.cpu(), None
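# Illustrative round trip (a sketch only; neither helper is referenced elsewhere in
# this file, so any activation-offload path that calls them lives outside it):
#   x_cpu = _OffloadToCPU.apply(x_gpu)                  # forward: GPU -> CPU
#   x_back = _RestoreToGPU.apply(x_cpu, x_gpu.device)   # forward: CPU -> GPU
#   # In backward the pair is traversed in reverse, so gradients move
#   # GPU -> CPU (through _RestoreToGPU) and back to GPU (through _OffloadToCPU).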
def tokenize_target_text(text, max_len=64, vocab_size=8192):
"""Convert target text string to character-level token IDs.
Uses modular hashing of Unicode code points to map any character
(ASCII, CJK, Cyrillic, math symbols, etc.) into a fixed vocabulary.
Token 0 is reserved for padding.
"""
ids = []
for ch in text[:max_len]:
ids.append(ord(ch) % (vocab_size - 1) + 1)
ids += [0] * (max_len - len(ids))
return ids
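# Worked example (values follow directly from the hashing rule above):
#   tokenize_target_text("NAD", max_len=8) -> [79, 66, 69, 0, 0, 0, 0, 0]
#     e.g. ord('N') = 78 -> 78 % 8191 + 1 = 79; unused slots are PAD (0).
#   tokenize_target_text("善", max_len=4) -> [5511, 0, 0, 0]   # ord('善') = 21892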
class ConditionCrossAttention(nn.Module):
"""Cross-attention for injecting condition features into VACE blocks.
Used for both glyph visual features (v2) and character-level text
tokens (v3). Each spatial position in the VACE hidden states queries
the condition tokens to determine what should be rendered at that location.
Zero-initialized output projection ensures the model starts from
pretrained VACE behavior and gradually learns to use the condition.
"""
def __init__(self, dim, num_heads, eps=1e-6):
super().__init__()
self.num_heads = num_heads
self.norm = nn.LayerNorm(dim, eps=eps)
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
# Zero-init output projection so condition attention starts as no-op
nn.init.zeros_(self.o.weight)
nn.init.zeros_(self.o.bias)
def forward(self, x, condition_tokens):
"""
Args:
x: VACE hidden states (B, seq_len, dim)
condition_tokens: condition features (B, num_tokens, dim)
Returns:
x + condition_attn_output (B, seq_len, dim)
"""
residual = x
x = self.norm(x)
q = self.q(x)
k = self.k(condition_tokens)
v = self.v(condition_tokens)
out = flash_attention(q, k, v, self.num_heads)
return residual + self.o(out)
# Backward compatibility alias
GlyphCrossAttention = ConditionCrossAttention
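# Shape sketch (illustrative sizes; dim / num_heads mirror the VaceWanModel defaults below):
#   attn = ConditionCrossAttention(dim=1536, num_heads=12)
#   x = torch.randn(1, 4096, 1536)      # VACE hidden states (B, seq_len, dim)
#   cond = torch.randn(1, 64, 1536)     # glyph or character condition tokens
#   out = attn(x, cond)                 # (1, 4096, 1536); equals x at init, since o is zero-init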
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = torch.nn.Linear(self.dim, self.dim)
self.after_proj = torch.nn.Linear(self.dim, self.dim)
def forward(self, c, x, context, t_mod, freqs, condition_tokens=None):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, context, t_mod, freqs)
# Condition cross-attention (injected after text cross-attn + FFN)
# Works for both glyph tokens (v2) and character tokens (v3)
if condition_tokens is not None and hasattr(self, 'condition_cross_attn'):
c = self.condition_cross_attn(c, condition_tokens)
c_skip = self.after_proj(c)
return c_skip, c
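# Per-block data flow (illustrative; see VaceWanModel.forward below):
#   c_skip, c = block(c, x, context, t_mod, freqs, condition_tokens)
#   # c_skip is collected into `hints` and returned for the main DiT stream to consume;
#   # c carries on as the input to the next VACE block.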
class GlyphEncoder(nn.Module):
"""Lightweight encoder that compresses glyph latents into tokens
for cross-attention in VACE blocks.
glyph_latent (16ch, from VAE) → Conv3D patch embed → spatial pooling
→ compact token sequence that each VACE block attends to.
"""
def __init__(self, in_channels=16, dim=1536, num_tokens=64, patch_size=(1, 2, 2)):
super().__init__()
self.num_tokens = num_tokens
# Patch embedding (same architecture as vace_patch_embedding)
self.patch_embed = nn.Conv3d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
# Compress spatial sequence to fixed number of tokens via cross-attention pooling
self.query_tokens = nn.Parameter(torch.randn(1, num_tokens, dim) * 0.02)
self.pool_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
self.pool_norm = nn.LayerNorm(dim)
# Output projection (zero-init for safe initialization)
self.out_proj = nn.Linear(dim, dim)
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, glyph_latent):
"""
Args:
glyph_latent: (B, 16, T, H, W) from VAE-encoded glyph video
Returns:
glyph_tokens: (B, num_tokens, dim) compressed glyph features
"""
# Patch embed: (B, dim, T, H/2, W/2)
x = self.patch_embed(glyph_latent)
B = x.shape[0]
# Flatten spatial: (B, dim, N) → (B, N, dim)
x = x.flatten(2).transpose(1, 2)
# Cross-attention pooling: compress N spatial tokens → num_tokens
queries = self.query_tokens.expand(B, -1, -1).to(dtype=x.dtype, device=x.device)
x_normed = self.pool_norm(x)
pooled, _ = self.pool_attn(queries, x_normed, x_normed)
return self.out_proj(pooled)
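# Shape sketch (a hypothetical latent; T, H, W depend on the VAE and clip size):
#   enc = GlyphEncoder(in_channels=16, dim=1536, num_tokens=64)
#   glyph_latent = torch.randn(1, 16, 1, 60, 104)
#   glyph_tokens = enc(glyph_latent)    # (1, 64, 1536); all zeros at init (out_proj is zero-init)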
class TargetTextEncoder(nn.Module):
"""Character-level encoder for target text strings.
Encodes the target text (e.g., "NAD", "CARLIFE", "善") at the character
level using learned embeddings and a small Transformer. This provides
character-precise identity information that T5's subword tokenization
cannot guarantee.
The output tokens are used as K/V in ConditionCrossAttention within
each VACE block, allowing each spatial position to query which character
it should render.
"""
def __init__(self, vocab_size=8192, max_len=64, dim=1536,
num_layers=2, num_heads=8):
super().__init__()
self.vocab_size = vocab_size
self.max_len = max_len
self.char_embed = nn.Embedding(vocab_size, dim, padding_idx=0)
self.pos_embed = nn.Embedding(max_len, dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim, nhead=num_heads, dim_feedforward=dim * 4,
batch_first=True, norm_first=True, activation='gelu',
)
self.transformer = nn.TransformerEncoder(
encoder_layer, num_layers=num_layers,
)
# Zero-init output projection for safe initialization
self.out_proj = nn.Linear(dim, dim)
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, token_ids):
"""
Args:
token_ids: (B, max_len) long tensor of character IDs (0 = PAD)
Returns:
text_tokens: (B, max_len, dim)
"""
B, L = token_ids.shape
positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
x = self.char_embed(token_ids) + self.pos_embed(positions)
# Padding mask: True means ignore this position
pad_mask = (token_ids == 0)
x = self.transformer(x, src_key_padding_mask=pad_mask)
return self.out_proj(x)
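# Usage sketch (pairs with tokenize_target_text above; shapes are illustrative):
#   ids = torch.tensor([tokenize_target_text("CARLIFE")])   # (1, 64) LongTensor
#   enc = TargetTextEncoder(dim=1536)
#   text_tokens = enc(ids)    # (1, 64, 1536); all zeros at init (out_proj is zero-init)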
class VaceWanModel(torch.nn.Module):
def __init__(
self,
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
vace_in_dim=96,
glyph_channels=0,
glyph_num_tokens=64,
use_target_text_encoder=False,
text_vocab_size=8192,
text_max_len=64,
text_num_layers=2,
patch_size=(1, 2, 2),
has_image_input=False,
dim=1536,
num_heads=12,
ffn_dim=8960,
eps=1e-6,
):
super().__init__()
self.vace_layers = vace_layers
self.vace_in_dim = vace_in_dim
self.glyph_channels = glyph_channels
self.use_target_text_encoder = use_target_text_encoder
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# VACE blocks
self.vace_blocks = torch.nn.ModuleList([
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
for i in self.vace_layers
])
# VACE patch embedding (original 96 channels, no glyph concat)
self.vace_patch_embedding = torch.nn.Conv3d(
vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
)
# Glyph pathway (v2: glyph video → GlyphEncoder → cross-attention)
if glyph_channels > 0:
self.glyph_encoder = GlyphEncoder(
in_channels=glyph_channels,
dim=dim,
num_tokens=glyph_num_tokens,
patch_size=patch_size,
)
for block in self.vace_blocks:
block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps)
# Target text pathway (v3: text string → TargetTextEncoder → cross-attention)
if use_target_text_encoder:
self.target_text_encoder = TargetTextEncoder(
vocab_size=text_vocab_size,
max_len=text_max_len,
dim=dim,
num_layers=text_num_layers,
num_heads=num_heads,
)
for block in self.vace_blocks:
block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps)
def _has_new_modules(self):
"""Check if this model has modules not present in pretrained checkpoints."""
return self.glyph_channels > 0 or self.use_target_text_encoder
def load_state_dict(self, state_dict, strict=True, assign=False):
if self._has_new_modules():
# New modules won't be in pretrained checkpoints.
# First, materialize any meta-tensor parameters so they can be assigned.
for name, param in self.named_parameters():
if param.is_meta:
materialized = torch.zeros(param.shape, dtype=param.dtype, device='cpu')
parts = name.split('.')
module = self
for p in parts[:-1]:
module = getattr(module, p)
setattr(module, parts[-1], torch.nn.Parameter(materialized))
result = super().load_state_dict(state_dict, strict=False, assign=assign)
# Re-initialize glyph modules (v2 mode)
if hasattr(self, 'glyph_encoder'):
for name, module in self.glyph_encoder.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv3d)):
if 'out_proj' not in name:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
if hasattr(self.glyph_encoder, 'query_tokens'):
torch.nn.init.normal_(self.glyph_encoder.query_tokens, std=0.02)
# Re-initialize target text encoder modules (v3 mode)
if hasattr(self, 'target_text_encoder'):
for name, module in self.target_text_encoder.named_modules():
if isinstance(module, (torch.nn.Linear,)):
if 'out_proj' not in name:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, torch.nn.Embedding):
torch.nn.init.normal_(module.weight, std=0.02)
if module.padding_idx is not None:
torch.nn.init.zeros_(module.weight[module.padding_idx])
# Re-initialize condition cross-attention modules
for block in self.vace_blocks:
if hasattr(block, 'condition_cross_attn'):
ca = block.condition_cross_attn
for name, module in ca.named_modules():
if isinstance(module, torch.nn.Linear):
if name == 'o':
torch.nn.init.zeros_(module.weight)
torch.nn.init.zeros_(module.bias)
else:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
# Cast re-initialized modules to model dtype (xavier_uniform_ etc. produce float32)
target_dtype = next((p.dtype for p in self.parameters() if p.dtype != torch.float32), torch.bfloat16)
if hasattr(self, 'glyph_encoder'):
self.glyph_encoder.to(dtype=target_dtype)
if hasattr(self, 'target_text_encoder'):
self.target_text_encoder.to(dtype=target_dtype)
for block in self.vace_blocks:
if hasattr(block, 'condition_cross_attn'):
block.condition_cross_attn.to(dtype=target_dtype)
return result
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self, x, vace_context, context, t_mod, freqs,
glyph_latent=None,
target_text_ids=None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
dim=1) for u in c
])
# Encode condition tokens for cross-attention
condition_tokens = None
if glyph_latent is not None and hasattr(self, 'glyph_encoder'):
condition_tokens = self.glyph_encoder(glyph_latent)
elif target_text_ids is not None and hasattr(self, 'target_text_encoder'):
condition_tokens = self.target_text_encoder(target_text_ids)
hints = []
for block in self.vace_blocks:
result = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs, condition_tokens
)
c_skip, c = result
hints.append(c_skip)
return hints
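# Construction sketch for the v3 (target-text) variant. Argument names mirror the
# defaults and the forward() signature above; x, vace_context, context, t_mod and
# freqs come from the parent Wan DiT and their shapes are not reproduced here:
#   vace = VaceWanModel(use_target_text_encoder=True, dim=1536, num_heads=12)
#   ids = torch.tensor([tokenize_target_text("NAD")])
#   hints = vace(x, vace_context, context, t_mod, freqs, target_text_ids=ids)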