|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Callable, Dict, Type |
|
|
|
|
|
|
|
|
@dataclass
class AttentionProcessorMetadata:
    """Per-attention-processor metadata stored in ``AttentionProcessorRegistry``.

    Attributes:
        skip_processor_output_fn: Callable associated with a processor class.
            The functions registered in this file accept ``(self, *args, **kwargs)``
            and hand back the incoming ``hidden_states`` (and, for some
            processors, ``encoder_hidden_states``) untouched.
    """

    # Required field: there is no default, so it must be supplied on construction.
    skip_processor_output_fn: Callable[[Any], Any]
|
|
|
|
|
|
|
|
@dataclass
class TransformerBlockMetadata:
    """Metadata describing a transformer block class registered in ``TransformerBlockRegistry``.

    Attributes:
        return_hidden_states_index: Integer (or ``None``) supplied at registration.
            Presumably the position of the hidden states in the block's return
            value -- it is not consumed inside this file; confirm against callers.
        return_encoder_hidden_states_index: Same idea for the encoder hidden
            states; ``None`` is used by several registrations in this file.
        _cls: The block class this metadata belongs to. Set by
            ``TransformerBlockRegistry.register``.
        _cached_parameter_indices: Lazily built map from ``forward`` parameter
            name to its positional index (``self`` excluded).
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Fetch the value passed for ``identifier`` in a call to ``self._cls.forward``.

        Keyword arguments take precedence. Otherwise the parameter's position
        is looked up in the signature of ``self._cls.forward`` (first parameter
        dropped) and the value is read from ``args``.

        Args:
            identifier: Name of the ``forward`` parameter to retrieve.
            args: Positional arguments of the call (without ``self``).
            kwargs: Keyword arguments of the call; ``None`` is treated as empty.

        Returns:
            The argument value bound to ``identifier``.

        Raises:
            ValueError: If the model class has not been set, if ``identifier``
                is not a parameter of ``forward``, or if ``args`` is too short
                to contain it. These are raised consistently on both the first
                call and subsequent (cached) calls.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]

        # Build the name -> position map once; later calls reuse it.
        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            names = list(inspect.signature(self._cls.forward).parameters.keys())
            names = names[1:]  # skip the first parameter (`self`)
            self._cached_parameter_indices = {param: i for i, param in enumerate(names)}

        # Validate on every call, not only the first one, so that the cached
        # path cannot leak a bare KeyError / IndexError to the caller.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            # Reaching position `index` requires at least `index + 1` positionals.
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")

        return args[index]
|
|
|
|
|
|
|
|
class AttentionProcessorRegistry:
    """Class-level lookup table from attention processor class to its metadata.

    The built-in entries are populated lazily: the first call to ``register``
    or ``get`` triggers ``_register_attention_processors_metadata()`` exactly
    once, guarded by ``_is_registered``.
    """

    # Maps processor class -> AttentionProcessorMetadata.
    _registry = {}

    # Flipped to True just before the built-in entries are populated.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate ``metadata`` with ``model_class``, replacing any existing entry."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata stored for ``model_class``.

        Raises:
            ValueError: If ``model_class`` has no entry in the registry.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        """Populate the built-in entries on first use; later calls do nothing."""
        if not cls._is_registered:
            # The flag is set *before* populating because the population
            # routine calls ``register`` itself, which re-enters this method.
            cls._is_registered = True
            _register_attention_processors_metadata()
|
|
|
|
|
|
|
|
class TransformerBlockRegistry:
    """Class-level lookup table from transformer block class to its metadata.

    The built-in entries are populated lazily: the first call to ``register``
    or ``get`` triggers ``_register_transformer_blocks_metadata()`` exactly
    once, guarded by ``_is_registered``.
    """

    # Maps block class -> TransformerBlockMetadata.
    _registry = {}

    # Flipped to True just before the built-in entries are populated.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate ``metadata`` with ``model_class``, replacing any existing entry.

        The metadata object is mutated: its ``_cls`` attribute is pointed at
        ``model_class`` before it is stored.
        """
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata stored for ``model_class``.

        Raises:
            ValueError: If ``model_class`` has no entry in the registry.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        """Populate the built-in entries on first use; later calls do nothing."""
        if not cls._is_registered:
            # The flag is set *before* populating because the population
            # routine calls ``register`` itself, which re-enters this method.
            cls._is_registered = True
            _register_transformer_blocks_metadata()
|
|
|
|
|
|
|
|
def _register_attention_processors_metadata():
    """Register the built-in attention processor classes with ``AttentionProcessorRegistry``.

    The processor classes are imported inside the function rather than at
    module level. Each class is paired with the module-level
    ``_skip_proc_output_fn_*`` callable of the matching name.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0

    # (processor class, skip-output function), in registration order.
    processor_table = (
        (AttnProcessor2_0, _skip_proc_output_fn_Attention_AttnProcessor2_0),
        (CogView4AttnProcessor, _skip_proc_output_fn_Attention_CogView4AttnProcessor),
        (WanAttnProcessor2_0, _skip_proc_output_fn_Attention_WanAttnProcessor2_0),
        (FluxAttnProcessor, _skip_proc_output_fn_Attention_FluxAttnProcessor),
        (QwenDoubleStreamAttnProcessor2_0, _skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0),
    )

    for processor_cls, skip_fn in processor_table:
        AttentionProcessorRegistry.register(
            model_class=processor_cls,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )
|
|
|
|
|
|
|
|
def _register_transformer_blocks_metadata():
    """Register the built-in transformer block classes with ``TransformerBlockRegistry``.

    The block classes are imported inside the function rather than at module
    level. Every class gets its own ``TransformerBlockMetadata`` instance,
    because ``TransformerBlockRegistry.register`` writes the class onto the
    metadata object it receives.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_bria import BriaTransformerBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index),
    # listed in registration order.
    block_table = (
        (BasicTransformerBlock, 0, None),
        (BriaTransformerBlock, 0, None),
        (CogVideoXBlock, 0, 1),
        (CogView4TransformerBlock, 0, 1),
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        (LTXVideoTransformerBlock, 0, None),
        (MochiTransformerBlock, 0, 1),
        (WanTransformerBlock, 0, None),
        (QwenImageTransformerBlock, 1, 0),
    )

    for block_cls, hidden_idx, encoder_idx in block_table:
        TransformerBlockRegistry.register(
            model_class=block_cls,
            # A fresh instance per class: ``register`` mutates it.
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_idx,
                return_encoder_hidden_states_index=encoder_idx,
            ),
        )
|
|
|
|
|
|
|
|
|
|
|
def _skip_attention___ret___hidden_states(self, *args, **kwargs): |
|
|
hidden_states = kwargs.get("hidden_states", None) |
|
|
if hidden_states is None and len(args) > 0: |
|
|
hidden_states = args[0] |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): |
|
|
hidden_states = kwargs.get("hidden_states", None) |
|
|
encoder_hidden_states = kwargs.get("encoder_hidden_states", None) |
|
|
if hidden_states is None and len(args) > 0: |
|
|
hidden_states = args[0] |
|
|
if encoder_hidden_states is None and len(args) > 1: |
|
|
encoder_hidden_states = args[1] |
|
|
return hidden_states, encoder_hidden_states |
|
|
|
|
|
|
|
|
# Per-processor aliases for the skip-output functions, grouped by what the
# aliased function returns.

# Aliases of the function that returns `hidden_states` only.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states

# Alias of the function that returns `(hidden_states, encoder_hidden_states)`.
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
|
|
|
|
|