# NOTE: Hugging Face Hub page chrome was captured along with this file
# (uploader: jbilcke-hf, commit a6d4ba6 "Upload 6 files", raw file size 6.88 kB).
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention
from einops import rearrange
from torch import nn
from enhance_a_video.enhance import enhance_score
from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames
def inject_enhance_for_hunyuanvideo(model: nn.Module) -> None:
    """
    Enable Enhance-A-Video on a HunyuanVideo model, in place.

    Two modifications are made:
      1. a forward pre-hook keeps the global frame count up to date, and
      2. every Attention module inside the transformer blocks has its
         processor swapped for the enhance-aware variant.
    """
    # Keep the global frame count current on every forward pass.
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
    # Swap in the enhance-aware attention processor on each transformer-block attention.
    attention_targets = (
        module
        for name, module in model.named_modules()
        if "transformer_blocks" in name and "attn" in name and isinstance(module, Attention)
    )
    for attention_module in attention_targets:
        attention_module.set_processor(EnhanceHunyuanVideoAttnProcessor2_0())
def num_frames_hook(module, args, kwargs):
    """
    Forward pre-hook that records the post-patchify frame count.

    Reads the latent video tensor (keyword `hidden_states` if present,
    otherwise the first positional argument), divides its frame axis
    (dim 2 — presumably a (B, C, T, H, W) layout; confirm against the
    model's forward signature) by the temporal patch size, and publishes
    the result via set_num_frames. Args/kwargs are returned unmodified.
    """
    hidden_states = kwargs["hidden_states"] if "hidden_states" in kwargs else args[0]
    frame_count = hidden_states.shape[2]
    temporal_patch = module.config.patch_size_t
    set_num_frames(frame_count // temporal_patch)
    return args, kwargs
class EnhanceHunyuanVideoAttnProcessor2_0:
    """Drop-in variant of diffusers' HunyuanVideoAttnProcessor2_0 that applies
    Enhance-A-Video: when enhancement is globally enabled, cross-frame
    "enhance scores" are computed from the image-token query/key and the
    attention output is scaled by them.
    """
    def __init__(self):
        # Requires torch 2.0's fused SDPA kernel; fail fast on older torch.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )
    def _get_enhance_scores(self, attn, query, key, encoder_hidden_states):
        """Compute enhance scores from the image (video) tokens only.

        `query`/`key` are post-projection tensors of shape
        (B, num_heads, seq_len, head_dim) — established by the
        unflatten/transpose in __call__. In the single-stream layout
        (attn.add_q_proj is None) the text tokens were concatenated after the
        image tokens, so they are sliced off the end before scoring.
        """
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            img_q, img_k = query[:, :, : -encoder_hidden_states.shape[1]], key[:, :, : -encoder_hidden_states.shape[1]]
        else:
            img_q, img_k = query, key
        # Frame count recorded by num_frames_hook for this forward pass.
        num_frames = get_num_frames()
        _, num_heads, ST, head_dim = img_q.shape
        # ST = T * S; assumes the image sequence length divides evenly by the
        # frame count — TODO confirm for non-uniform latent shapes.
        spatial_dim = ST / num_frames
        spatial_dim = int(spatial_dim)
        # Fold the spatial axis into the batch so enhance_score operates along
        # the temporal axis only.
        query_image = rearrange(
            img_q, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
        )
        key_image = rearrange(img_k, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim)
        return enhance_score(query_image, key_image, head_dim, num_frames)
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Run HunyuanVideo attention, optionally scaling the output by
        Enhance-A-Video scores.

        Returns a (hidden_states, encoder_hidden_states) pair; the second
        element is None when no encoder stream was supplied.
        """
        # Single-stream layout: no separate text projections, so the text
        # tokens are appended after the image tokens and share the image QKV.
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        # 1. QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # (B, seq, heads*head_dim) -> (B, heads, seq, head_dim)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        # 2. QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        # 3. Rotational positional embeddings applied to latent stream
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                # RoPE applies only to the image tokens; the trailing text
                # tokens pass through unchanged.
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)
        # ========== Enhance-A-Video ==========
        # Scores are computed before the encoder QKV is concatenated (step 4),
        # so in the dual-stream case query/key here are image-only.
        if is_enhance_enabled():
            enhance_scores = self._get_enhance_scores(attn, query, key, encoder_hidden_states)
        # ========== Enhance-A-Video ==========
        # 4. Encoder condition QKV projection and normalization
        # Dual-stream layout: text tokens get their own projections and are
        # appended along the sequence axis.
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)
            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)
        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)
        # 6. Output projection
        if encoder_hidden_states is not None:
            # Split the concatenated sequence back into image / text streams
            # (text tokens always sit at the end).
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )
            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)
            if getattr(attn, "to_add_out", None) is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        # ========== Enhance-A-Video ==========
        # Scale the (image-stream) attention output by the enhance scores;
        # broadcasting semantics depend on enhance_score's return shape.
        if is_enhance_enabled():
            hidden_states = hidden_states * enhance_scores
        # ========== Enhance-A-Video ==========
        return hidden_states, encoder_hidden_states