diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb67a43fa4e5791ab58e7e40260bc3df8b6bc7cc --- /dev/null +++ b/diffsynth/__init__.py @@ -0,0 +1 @@ +from .core import * diff --git a/diffsynth/configs/__init__.py b/diffsynth/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad5b7322143d89beb1dbec64d505c5c8b1ecf61 --- /dev/null +++ b/diffsynth/configs/__init__.py @@ -0,0 +1,2 @@ +from .model_configs import MODEL_CONFIGS +from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2d94d394da5f06006320e818b07d60a9b3b09e --- /dev/null +++ b/diffsynth/configs/model_configs.py @@ -0,0 +1,888 @@ +qwen_image_series = [ + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors") + "model_hash": "0319a1cb19835fb510907dd3367c95ff", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8004730443f55db63092006dd9f7110e", + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "ed4ea5824d55ec3107b09815e318123a", + "model_name": "qwen_image_vae", + "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors") + "model_hash": "073bce9cf969e317e5662cd570c3e79c", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors") + "model_hash": "a9e54e480a628f0b956a688a81c33bab", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + "extra_kwargs": {"additional_in_dim": 4}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors") + "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8", + "model_name": "siglip2_image_encoder", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors") + "model_hash": "5722b5c873720009de96422993b15682", + "model_name": "dinov3_image_encoder", + "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder", + }, + { + # Example: + "model_hash": "a166c33455cdbd89c0888a3645ca5c0f", + "model_name": "qwen_image_image2lora_coarse", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + }, + { + # Example: + "model_hash": "a5476e691767a4da6d3a6634a10f7408", + "model_name": 
"qwen_image_image2lora_fine", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64} + }, + { + # Example: + "model_hash": "0aad514690602ecaff932c701cb4b0bb", + "model_name": "qwen_image_image2lora_style", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 64, "use_residual": False} + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8dc8cda05de16c73afa755e2c1ce2839", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True} + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "44b39ddc499e027cfb24f7878d7416b9", + "model_name": "qwen_image_vae", + "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE", + "extra_kwargs": {"image_channels": 4} + }, +] + +wan_series = [ + { + # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors") + "model_hash": "5ec04e02b42d2580483ad69f4e76346a", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth") + "model_hash": "9c8818c2cbea55eca56c7b447df170da", + "model_name": "wan_video_text_encoder", + "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth") + "model_hash": "ccc42284ea13e1ad04693284c7a09be6", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors") + "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel", + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers" + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": 
"5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_vap", + "model_class": "diffsynth.models.wan_video_mot.MotWanModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + "model_hash": "5941c53e207d62f20f9025686193c40b", + "model_name": "wan_video_image_encoder", + "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter" + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors") + "model_hash": "dbd5ec76bbf977983f972c151d545389", + "model_name": "wan_video_motion_controller", + "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "9269f8db9040a9d860eaca435be61814", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "349723183fc063b2bfc10bb2835cf677", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6d6ccde6845b95ad9114ab993d917893", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "efa44cddf936c70abd0ea28b6cbe946c", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: 
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6bfcfb3b342cb286ce886889d519a77e", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "70ddad9d3a133785da5ea371aae09504", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "b61c605c2adbd23124d152ed28e049ae", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "26bde73488a92e64cc20b0a7485b9e5b", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 
'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "extra_kwargs": {"use_target_text_encoder": True}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'glyph_channels': 16, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_animate_adapter", + "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter" + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "47dbeab5e560db3180adf51dc0232fb1", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False} + }, + { + # 
Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "2267d489f0ceb9f21836532952852ee5", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False}, + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "5b013604280dd715f8457c6ed6d6a626", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "966cffdcc52f9c46c391768b27637614", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel", + "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "1f5ab7703c6fc803fdded85ff040c316", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth") + "model_hash": "e1de6c02cdac79f8b739f4d3698cd216", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors") + "model_hash": "06be60f3a4526586d8431cd038a71486", + "model_name": "wans2v_audio_encoder", + "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors") + "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 
'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True} + }, +] + +flux_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors") + "model_hash": "a29710fea6dddb0314663ee823598e50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Supported due to historical reasons. + "model_hash": "605c56eab23e9e2af863ad8f0813a25d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78", + "model_name": "flux_text_encoder_clip", + "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors") + "model_hash": "22540b49eaedbc2f2784b2091a234c7c", + "model_name": "flux_text_encoder_t5", + "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors") + "model_hash": "d02f41c13549fa5093d3521f62a5570a", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "extra_kwargs": {'input_dim': 196, 'num_blocks': 8}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + "model_hash": "0629116fce1472503a66992f96f3eb1a", + "model_name": "flux_value_controller", + "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", + }, + { + # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "52357cb26250681367488a8954c271e8", + "model_name": "flux_controlnet", + "model_class": 
"diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}, + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "78d18b9101345ff695f312e7e62538c0", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}, + }, + { + # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "b001c89139b5f053c715fe772362dd2a", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_single_blocks": 0}, + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin") + "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481", + "model_name": "infiniteyou_image_projector", + "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors") + "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors") + "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab", + "model_name": "flux_lora_encoder", + "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors") + "model_hash": "30143afb2dea73d1ac580e0787628f8c", + "model_name": "flux_lora_patcher", + "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors") + "model_hash": "2bd19e845116e4f875a0a048e27fc219", + "model_name": "nexus_gen_llm", + "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "nexus_gen_editing_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "nexus_gen_generation_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin") + "model_hash": "4daaa66cc656a8fe369908693dad0a35", + "model_name": "flux_ipadapter", + "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors") + "model_hash": "04d8c1e20a1f1b25f7434f111992a33f", + "model_name": "siglip_vision_model", + "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "step1x_connector", + "model_class": "diffsynth.models.step1x_connector.Qwen2Connector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + "extra_kwargs": {"disable_guidance_embedder": True}, + }, + { + # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors") + "model_hash": "3394f306c4cbf04334b712bf5aaed95f", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, +] + +flux2_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "28fca3d8e5bf2a2d1271748a773f6757", + "model_name": "flux2_text_encoder", + "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors") + "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "c54288e3ee12ca215898840682337b95", + "model_name": "flux2_vae", + "model_class": "diffsynth.models.flux2_vae.Flux2VAE", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "3bde7b817fec8143028b6825a63180df", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20} + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "9195f3ea256fcd0ae6d929c203470754", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "8B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "39c6fc48f07bebecedbbaa971ff466c8", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24} + }, +] + +z_image_series = [ + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors") + "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "0f050f62a88876fea6eae0a18dac5a2e", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors") + "model_hash": 
"aa3563718e5c3ecde3dfbb020ca61180", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + "extra_kwargs": {"siglip_feat_dim": 1152}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors") + "model_hash": "89d48e420f45cff95115a9f3e698d44a", + "model_name": "siglip_vision_model_428m", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", + }, + { + # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") + "model_hash": "1677708d40029ab380a95f6c731a57d7", + "model_name": "z_image_controlnet", + "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet", + }, + { + # Example: ??? + "model_hash": "9510cb8cd1dd34ee0e4f111c24905510", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 128}, + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors") + "model_hash": "1392adecee344136041e70553f875f31", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "0.6B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, + { + # To ensure compatibility with the `model.diffusion_model` prefix introduced by other frameworks. + "model_hash": "8cf241a0d32f93d5de368502a086852f", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter", + }, +] +""" +Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2 +Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage +For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")) +and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported. +We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules, +and avoid redundant memory usage when users only want to use part of the model. 
+""" +ltx2_series = [ + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors") + "model_hash": "c567aaa37d5ed7454c73aa6024458661", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors") + "model_hash": "7f7e904a53260ec0351b05f32153754b", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors") + "model_hash": "dc6029ca2825147872b45e35a2dc3a97", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors") + "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + { + # Example: 
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors") + "model_hash": "f471360f6b24bef702ab73133d9f8bb9", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors") + "model_hash": "29338f3b95e7e312a3460a482e4f4554", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors") + "model_hash": "981629689c8be92a712ab3c5eb4fc3f6", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors") + "model_hash": "33917f31c4a79196171154cca39f165e", + "model_name": "ltx2_text_encoder", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "c79c458c6e99e0e14d47e676761732d2", + "model_name": "ltx2_latent_upsampler", + "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "extra_kwargs": {"encoder_version": "ltx-2.3"}, + 
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "extra_kwargs": {"decoder_version": "ltx-2.3"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors") + "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors") + "model_hash": "aed408774d694a2452f69936c32febb5", + "model_name": "ltx2_latent_upsampler", + "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", + "extra_kwargs": {"rational_resampler": False}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors") + "model_hash": "1c55afad76ed33c112a2978550b524d1", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors") + "model_hash": "eecdc07c2ec30863b8a2b8b2134036cf", + "model_name": "ltx2_video_vae_encoder", + "model_class": 
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "extra_kwargs": {"encoder_version": "ltx-2.3"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors") + "model_hash": "deda2f542e17ee25bc8c38fd605316ea", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "extra_kwargs": {"decoder_version": "ltx-2.3"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors") + "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb", + "model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors") + "model_hash": "29338f3b95e7e312a3460a482e4f4554", + "model_name": "ltx2_audio_vae_encoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors") + "model_hash": "cd436c99e69ec5c80f050f0944f02a15", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors") + "model_hash": "05da2aab1c4b061f72c426311c165a43", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, +] +anima_series = [ + { + # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors") + "model_hash": "a9995952c2d8e63cf82e115005eb61b9", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "0.6B"}, + }, + { + # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors") + "model_hash": "417673936471e79e31ed4d186d7a3f4a", + "model_name": "anima_dit", + "model_class": "diffsynth.models.anima_dit.AnimaDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter", + } +] + +mova_series = [ + # Example: ModelConfig(model_id="openmoss/MOVA-720p", 
origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors") + { + "model_hash": "8c57e12790e2c45a64817e0ce28cde2f", + "model_name": "mova_audio_dit", + "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors") + { + "model_hash": "418517fb2b4e919d2cac8f314fcf82ac", + "model_name": "mova_audio_vae", + "model_class": "diffsynth.models.mova_audio_vae.DacVAE", + }, + # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors") + { + "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb", + "model_name": "mova_dual_tower_bridge", + "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", + }, +] +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..de276891f94451aac353ec1ff378bb9ed9bae814 --- /dev/null +++ b/diffsynth/configs/vram_management_module_maps.py @@ -0,0 +1,284 @@ +flux_general_vram_config = { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule", +} + +VRAM_MANAGEMENT_MODULE_MAPS = { + "diffsynth.models.qwen_image_dit.QwenImageDiT": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_vae.QwenImageVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": 
"diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": { + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": { + "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit_s2v.WanS2VModel": { + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit.WanModel": { + "diffsynth.models.wan_video_dit.MLP": 
"diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule", + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_image_encoder.WanImageEncoder": { + "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_mot.MotWanModel": { + "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_text_encoder.WanTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vace.VaceWanModel": { + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE38": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": 
"diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wav2vec.WanS2VAudioEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_dit.FluxDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config, + "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config, + "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config, + "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config, + "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config, + "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config, + "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config, + "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_dit.Flux2DiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + 
}, + "diffsynth.models.flux2_vae.Flux2VAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_dit.ZImageDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_controlnet.ZImageControlNet": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": { + "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.ltx2_dit.LTXModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": { + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + 
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.anima_dit.AnimaDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.mova_audio_dit.MovaAudioDit": { + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule", + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.mova_audio_vae.DacVAE": { + "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, +} + +def QwenImageTextEncoder_Module_Map_Updater(): + current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"] + from packaging import version + import transformers + if version.parse(transformers.__version__) >= version.parse("5.2.0"): + # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly + current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None) + current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule" + return current + +VERSION_CHECKER_MAPS = { + "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater, +} \ No newline at end of file diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0a6c8774ba11b6e2dd9d54775d50431acdaaee --- /dev/null +++ b/diffsynth/core/__init__.py @@ -0,0 +1,6 @@ +from .attention import * +from .data import * +from .gradient import * +from .loader import * +from .vram import * +from .device import * diff --git a/diffsynth/core/attention/__init__.py b/diffsynth/core/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45cf8a4382397aa4ff6558f37191c726d821b95a --- /dev/null +++ b/diffsynth/core/attention/__init__.py @@ -0,0 +1 @@ +from .attention import attention_forward diff --git a/diffsynth/core/attention/attention.py b/diffsynth/core/attention/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..630d37545441eae401a198b065820307e286586b --- /dev/null +++ b/diffsynth/core/attention/attention.py @@ -0,0 +1,121 @@ +import torch, os +from einops import rearrange + + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + 
FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +try: + import xformers.ops as xops + XFORMERS_AVAILABLE = True +except ModuleNotFoundError: + XFORMERS_AVAILABLE = False + + +def initialize_attention_priority(): + if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None: + return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower() + elif FLASH_ATTN_3_AVAILABLE: + return "flash_attention_3" + elif FLASH_ATTN_2_AVAILABLE: + return "flash_attention_2" + elif SAGE_ATTN_AVAILABLE: + return "sage_attention" + elif XFORMERS_AVAILABLE: + return "xformers" + else: + return "torch" + + +ATTENTION_IMPLEMENTATION = initialize_attention_priority() + + +def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if q_pattern != required_in_pattern: + q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims) + if k_pattern != required_in_pattern: + k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) + if v_pattern != required_in_pattern: + v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) + return q, k, v + + +def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if out_pattern != required_out_pattern: + out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims) + return out + + +def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale) + if isinstance(out, tuple): + out = out[0] + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n 
s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = sageattn(q, k, v, sm_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = xops.memory_efficient_attention(q, k, v, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False): + if compatibility_mode or (attn_mask is not None): + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale) + else: + if ATTENTION_IMPLEMENTATION == "flash_attention_3": + return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "flash_attention_2": + return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "sage_attention": + return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "xformers": + return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + else: + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) diff --git a/diffsynth/core/data/__init__.py b/diffsynth/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d494a277d81eeb2a9575155eb983d8bc3879590a --- /dev/null +++ b/diffsynth/core/data/__init__.py @@ -0,0 +1 @@ +from .unified_dataset import UnifiedDataset diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..8a072e21f2e2e8a38026373029667f9f7c119ef7 --- /dev/null +++ b/diffsynth/core/data/operators.py @@ -0,0 +1,280 @@ +import math +import torch, torchvision, imageio, os +import imageio.v3 as iio +from PIL import Image +import torchaudio + + +class DataProcessingPipeline: + def __init__(self, operators=None): + self.operators: list[DataProcessingOperator] = [] if operators is None else operators + + def __call__(self, data): + for operator in self.operators: + data = operator(data) + return data + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline(self.operators + pipe.operators) + + +class DataProcessingOperator: + def __call__(self, data): + raise NotImplementedError("DataProcessingOperator cannot be called directly.") + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline([self]).__rshift__(pipe) + + +class DataProcessingOperatorRaw(DataProcessingOperator): + def __call__(self, data): + return data + + +class ToInt(DataProcessingOperator): + def __call__(self, data): + return int(data) + + +class ToFloat(DataProcessingOperator): + def __call__(self, data): + 
return float(data) + + +class ToStr(DataProcessingOperator): + def __init__(self, none_value=""): + self.none_value = none_value + + def __call__(self, data): + if data is None: data = self.none_value + return str(data) + + +class LoadImage(DataProcessingOperator): + def __init__(self, convert_RGB=True, convert_RGBA=False): + self.convert_RGB = convert_RGB + self.convert_RGBA = convert_RGBA + + def __call__(self, data: str): + image = Image.open(data) + if self.convert_RGB: image = image.convert("RGB") + if self.convert_RGBA: image = image.convert("RGBA") + return image + + +class ImageCropAndResize(DataProcessingOperator): + def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1): + self.height = height + self.width = width + self.max_pixels = max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.height is None or self.width is None: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + def __call__(self, data: Image.Image): + image = self.crop_and_resize(data, *self.get_height_width(data)) + return image + + +class ToList(DataProcessingOperator): + def __call__(self, data): + return [data] + + +class FrameSamplerByRateMixin: + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + self.frame_rate = frame_rate + self.fix_frame_rate = fix_frame_rate + + def get_reader(self, data: str): + return imageio.get_reader(data) + + def get_available_num_frames(self, reader): + if not self.fix_frame_rate: + return reader.count_frames() + meta_data = reader.get_meta_data() + total_original_frames = int(reader.count_frames()) + duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps'] + total_available_frames = math.floor(duration * self.frame_rate) + return int(total_available_frames) + + def get_num_frames(self, reader): + num_frames = self.num_frames + total_frames = self.get_available_num_frames(reader) + if int(total_frames) < num_frames: + num_frames = total_frames + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int: + if not self.fix_frame_rate: + return new_sequence_id + target_time_in_seconds = new_sequence_id / self.frame_rate + raw_frame_index_float = target_time_in_seconds * raw_frame_rate + frame_id 
= int(round(raw_frame_index_float)) + frame_id = min(frame_id, total_raw_frames - 1) + return frame_id + + +class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False): + FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate) + # frame_processor is built into the video loader for efficiency. + self.frame_processor = frame_processor + + def __call__(self, data: str): + reader = self.get_reader(data) + raw_frame_rate = reader.get_meta_data()['fps'] + total_raw_frames = reader.count_frames() + total_available = self.get_available_num_frames(reader) + # Pad short videos with the last frame instead of reducing num_frames + num_frames = self.num_frames + frames = [] + for frame_id in range(num_frames): + if frame_id < total_available: + raw_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames) + frame = reader.get_data(raw_id) + frame = Image.fromarray(frame) + frame = self.frame_processor(frame) + frames.append(frame) + else: + # Pad with the last frame + frames.append(frames[-1]) + reader.close() + return frames + + +class SequencialProcess(DataProcessingOperator): + def __init__(self, operator=lambda x: x): + self.operator = operator + + def __call__(self, data): + return [self.operator(i) for i in data] + + +class LoadGIF(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is built into the video loader for efficiency. 
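+ # (The processor is applied to each decoded frame as it is read, so per-frame resizing happens during loading rather than on the full frame list afterwards.)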
+ self.frame_processor = frame_processor + + def get_num_frames(self, path): + num_frames = self.num_frames + images = iio.imread(path, mode="RGB") + if len(images) < num_frames: + num_frames = len(images) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + num_frames = self.get_num_frames(data) + frames = [] + images = iio.imread(data, mode="RGB") + for img in images: + frame = Image.fromarray(img) + frame = self.frame_processor(frame) + frames.append(frame) + if len(frames) >= num_frames: + break + return frames + + +class RouteByExtensionName(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data: str): + file_ext_name = data.split(".")[-1].lower() + for ext_names, operator in self.operator_map: + if ext_names is None or file_ext_name in ext_names: + return operator(data) + raise ValueError(f"Unsupported file: {data}") + + +class RouteByType(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data): + for dtype, operator in self.operator_map: + if dtype is None or isinstance(data, dtype): + return operator(data) + raise ValueError(f"Unsupported data: {data}") + + +class LoadTorchPickle(DataProcessingOperator): + def __init__(self, map_location="cpu"): + self.map_location = map_location + + def __call__(self, data): + return torch.load(data, map_location=self.map_location, weights_only=False) + + +class ToAbsolutePath(DataProcessingOperator): + def __init__(self, base_path=""): + self.base_path = base_path + + def __call__(self, data): + return os.path.join(self.base_path, data) + + +class LoadAudio(DataProcessingOperator): + def __init__(self, sr=16000): + self.sr = sr + def __call__(self, data: str): + import librosa + input_audio, sample_rate = librosa.load(data, sr=self.sr) + return input_audio + + +class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin): + + def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True): + FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate) + + def __call__(self, data: str): + reader = self.get_reader(data) + num_frames = self.get_num_frames(reader) + duration = num_frames / self.frame_rate + waveform, sample_rate = torchaudio.load(data) + target_samples = int(duration * sample_rate) + current_samples = waveform.shape[-1] + if current_samples > target_samples: + waveform = waveform[..., :target_samples] + elif current_samples < target_samples: + padding = target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, padding)) + return waveform, sample_rate diff --git a/diffsynth/core/data/unified_dataset.py b/diffsynth/core/data/unified_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd9c51142bd587c22f3922fd4e26f72c722b503 --- /dev/null +++ b/diffsynth/core/data/unified_dataset.py @@ -0,0 +1,118 @@ +from .operators import * +import torch, json, pandas + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + special_operator_map=None, + max_data_items=None, + ): + self.base_path = base_path + self.metadata_path = metadata_path + self.repeat = repeat + 
self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.cached_data_operator = LoadTorchPickle() + self.special_operator_map = {} if special_operator_map is None else special_operator_map + self.max_data_items = max_data_items + self.data = [] + self.cached_data = [] + self.load_from_cache = metadata_path is None + self.load_metadata(metadata_path) + + @staticmethod + def default_image_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))), + ]) + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + frame_rate=24, fix_frame_rate=False, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + frame_rate=frame_rate, fix_frame_rate=fix_frame_rate, + )), + ])), + ]) + + def search_for_cached_data_files(self, path): + for file_name in os.listdir(path): + subpath = os.path.join(path, file_name) + if os.path.isdir(subpath): + self.search_for_cached_data_files(subpath) + elif subpath.endswith(".pth"): + self.cached_data.append(subpath) + + def load_metadata(self, metadata_path): + if metadata_path is None: + print("No metadata_path. 
Searching for cached data files.") + self.search_for_cached_data_files(self.base_path) + print(f"{len(self.cached_data)} cached data files found.") + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in f: + metadata.append(json.loads(line.strip())) + self.data = metadata + else: + metadata = pandas.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + def __getitem__(self, data_id): + if self.load_from_cache: + data = self.cached_data[data_id % len(self.cached_data)] + data = self.cached_data_operator(data) + else: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + if key in self.special_operator_map: + data[key] = self.special_operator_map[key](data[key]) + elif key in self.data_file_keys: + data[key] = self.main_data_operator(data[key]) + return data + + def __len__(self): + if self.max_data_items is not None: + return self.max_data_items + elif self.load_from_cache: + return len(self.cached_data) * self.repeat + else: + return len(self.data) * self.repeat + + def check_data_equal(self, data1, data2): + # Debug only + if len(data1) != len(data2): + return False + for k in data1: + if data1[k] != data2[k]: + return False + return True diff --git a/diffsynth/core/device/__init__.py b/diffsynth/core/device/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..889d6823adb203630d2173e8ab25fcc350b32ce4 --- /dev/null +++ b/diffsynth/core/device/__init__.py @@ -0,0 +1,2 @@ +from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name +from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE diff --git a/diffsynth/core/device/npu_compatible_device.py b/diffsynth/core/device/npu_compatible_device.py new file mode 100644 index 0000000000000000000000000000000000000000..d96b8fb2479688e2209a7c8349d76ba093aa8e99 --- /dev/null +++ b/diffsynth/core/device/npu_compatible_device.py @@ -0,0 +1,107 @@ +import importlib +import torch +from typing import Any + + +def is_torch_npu_available(): + return importlib.util.find_spec("torch_npu") is not None + + +IS_CUDA_AVAILABLE = torch.cuda.is_available() +IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available() + +if IS_NPU_AVAILABLE: + import torch_npu + + torch.npu.config.allow_internal_format = False + + +def get_device_type() -> str: + """Get device type based on current machine, currently only support CPU, CUDA, NPU.""" + if IS_CUDA_AVAILABLE: + device = "cuda" + elif IS_NPU_AVAILABLE: + device = "npu" + else: + device = "cpu" + + return device + + +def get_torch_device() -> Any: + """Get torch attribute based on device type, e.g. 
torch.cuda or torch.npu""" + device_name = get_device_type() + + try: + return getattr(torch, device_name) + except AttributeError: + print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.") + return torch.cuda + + +def get_device_id() -> int: + """Get current device id based on device type.""" + return get_torch_device().current_device() + + +def get_device_name() -> str: + """Get current device name based on device type.""" + return f"{get_device_type()}:{get_device_id()}" + + +def synchronize() -> None: + """Execute torch synchronize operation.""" + get_torch_device().synchronize() + + +def empty_cache() -> None: + """Execute torch empty cache operation.""" + get_torch_device().empty_cache() + + +def get_nccl_backend() -> str: + """Return distributed communication backend type based on device type.""" + if IS_CUDA_AVAILABLE: + return "nccl" + elif IS_NPU_AVAILABLE: + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.") + + +def enable_high_precision_for_bf16(): + """ + Set high accumulation dtype for matmul and reduction. + """ + if IS_CUDA_AVAILABLE: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + if IS_NPU_AVAILABLE: + torch.npu.matmul.allow_tf32 = False + torch.npu.matmul.allow_bf16_reduced_precision_reduction = False + + +def parse_device_type(device): + if isinstance(device, str): + if device.startswith("cuda"): + return "cuda" + elif device.startswith("npu"): + return "npu" + else: + return "cpu" + elif isinstance(device, torch.device): + return device.type + + +def parse_nccl_backend(device_type): + if device_type == "cuda": + return "nccl" + elif device_type == "npu": + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.") + + +def get_available_device_type(): + return get_device_type() diff --git a/diffsynth/core/gradient/__init__.py b/diffsynth/core/gradient/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57914792a78ec32f69c3c99ae37535598efc8d52 --- /dev/null +++ b/diffsynth/core/gradient/__init__.py @@ -0,0 +1 @@ +from .gradient_checkpoint import gradient_checkpoint_forward diff --git a/diffsynth/core/gradient/gradient_checkpoint.py b/diffsynth/core/gradient/gradient_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d8aeed391a638e2e45de8947989d60db65ca5ac7 --- /dev/null +++ b/diffsynth/core/gradient/gradient_checkpoint.py @@ -0,0 +1,37 @@ +import torch +import warnings +# Suppress checkpoint requires_grad warning - gradients flow through model params, not inputs +warnings.filterwarnings("ignore", message=".*None of the inputs have requires_grad.*") + + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=True, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=True, + ) + else: + model_output = model(*args, **kwargs) + 
return model_output diff --git a/diffsynth/core/loader/__init__.py b/diffsynth/core/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f56d814bae40436f66bca583e33d180d6e11247 --- /dev/null +++ b/diffsynth/core/loader/__init__.py @@ -0,0 +1,3 @@ +from .file import load_state_dict, hash_state_dict_keys, hash_model_file +from .model import load_model, load_model_with_disk_offload +from .config import ModelConfig diff --git a/diffsynth/core/loader/config.py b/diffsynth/core/loader/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ce83cf30c17a37a242597323c4269c8dba0fd7 --- /dev/null +++ b/diffsynth/core/loader/config.py @@ -0,0 +1,119 @@ +import torch, glob, os +from typing import Optional, Union, Dict +from dataclasses import dataclass +from modelscope import snapshot_download +from huggingface_hub import snapshot_download as hf_snapshot_download +from typing import Optional + + +@dataclass +class ModelConfig: + path: Union[str, list[str]] = None + model_id: str = None + origin_file_pattern: Union[str, list[str]] = None + download_source: str = None + local_model_path: str = None + skip_download: bool = None + offload_device: Optional[Union[str, torch.device]] = None + offload_dtype: Optional[torch.dtype] = None + onload_device: Optional[Union[str, torch.device]] = None + onload_dtype: Optional[torch.dtype] = None + preparing_device: Optional[Union[str, torch.device]] = None + preparing_dtype: Optional[torch.dtype] = None + computation_device: Optional[Union[str, torch.device]] = None + computation_dtype: Optional[torch.dtype] = None + clear_parameters: bool = False + state_dict: Dict[str, torch.Tensor] = None + + def check_input(self): + if self.path is None and self.model_id is None: + raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. 
`skip_download=True` only supports the first one.""") + + def parse_original_file_pattern(self): + if self.origin_file_pattern in [None, "", "./"]: + return "*" + elif self.origin_file_pattern.endswith("/"): + return self.origin_file_pattern + "*" + else: + return self.origin_file_pattern + + def parse_download_source(self): + if self.download_source is None: + if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None: + return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') + else: + return "modelscope" + else: + return self.download_source + + def parse_skip_download(self): + if self.skip_download is None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true": + return True + elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false": + return False + else: + return False + else: + return self.skip_download + + def download(self): + origin_file_pattern = self.parse_original_file_pattern() + downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) + download_source = self.parse_download_source() + if download_source.lower() == "modelscope": + snapshot_download( + self.model_id, + local_dir=os.path.join(self.local_model_path, self.model_id), + allow_file_pattern=origin_file_pattern, + ignore_file_pattern=downloaded_files, + local_files_only=False + ) + elif download_source.lower() == "huggingface": + hf_snapshot_download( + self.model_id, + local_dir=os.path.join(self.local_model_path, self.model_id), + allow_patterns=origin_file_pattern, + ignore_patterns=downloaded_files, + local_files_only=False + ) + else: + raise ValueError("`download_source` should be `modelscope` or `huggingface`.") + + def require_downloading(self): + if self.path is not None: + return False + skip_download = self.parse_skip_download() + return not skip_download + + def reset_local_model_path(self): + if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None: + self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') + elif self.local_model_path is None: + self.local_model_path = "./models" + + def download_if_necessary(self): + self.check_input() + self.reset_local_model_path() + if self.require_downloading(): + self.download() + if self.path is None: + if self.origin_file_pattern in [None, "", "./"]: + self.path = os.path.join(self.local_model_path, self.model_id) + else: + self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) + if isinstance(self.path, list) and len(self.path) == 1: + self.path = self.path[0] + + def vram_config(self): + return { + "offload_device": self.offload_device, + "offload_dtype": self.offload_dtype, + "onload_device": self.onload_device, + "onload_dtype": self.onload_dtype, + "preparing_device": self.preparing_device, + "preparing_dtype": self.preparing_dtype, + "computation_device": self.computation_device, + "computation_dtype": self.computation_dtype, + } diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py new file mode 100644 index 0000000000000000000000000000000000000000..67d88155bff64d005a5115f42d549ea444d9b90f --- /dev/null +++ b/diffsynth/core/loader/file.py @@ -0,0 +1,130 @@ +from safetensors import safe_open +import torch, hashlib + + +def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0): + if isinstance(file_path, list): + state_dict = {} + for file_path_ in file_path: + state_dict.update(load_state_dict(file_path_, torch_dtype, 
device, pin_memory=pin_memory, verbose=verbose)) + else: + if verbose >= 1: + print(f"Loading file [started]: {file_path}") + if file_path.endswith(".safetensors"): + state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + else: + state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + # If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster. + if pin_memory: + for i in state_dict: + state_dict[i] = state_dict[i].pin_memory() + if verbose >= 1: + print(f"Loading file [done]: {file_path}") + return state_dict + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): + state_dict = {} + with safe_open(file_path, framework="pt", device=str(device)) as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): + state_dict = torch.load(file_path, map_location=device, weights_only=True) + if len(state_dict) == 1: + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + elif "module" in state_dict: + state_dict = state_dict["module"] + elif "model_state" in state_dict: + state_dict = state_dict["model_state"] + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + + +def load_keys_dict(file_path): + if isinstance(file_path, list): + state_dict = {} + for file_path_ in file_path: + state_dict.update(load_keys_dict(file_path_)) + return state_dict + if file_path.endswith(".safetensors"): + return load_keys_dict_from_safetensors(file_path) + else: + return load_keys_dict_from_bin(file_path) + + +def load_keys_dict_from_safetensors(file_path): + keys_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + keys_dict[k] = f.get_slice(k).get_shape() + return keys_dict + + +def convert_state_dict_to_keys_dict(state_dict): + keys_dict = {} + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + keys_dict[k] = list(v.shape) + else: + keys_dict[k] = convert_state_dict_to_keys_dict(v) + return keys_dict + + +def load_keys_dict_from_bin(file_path): + state_dict = load_state_dict_from_bin(file_path) + keys_dict = convert_state_dict_to_keys_dict(state_dict) + return keys_dict + + +def convert_keys_dict_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, dict): + keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape)) + else: + if with_shape: + shape = "_".join(map(str, 
list(value))) + keys.append(key + ":" + shape) + keys.append(key) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_model_file(path, with_shape=True): + keys_dict = load_keys_dict(path) + keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py new file mode 100644 index 0000000000000000000000000000000000000000..18004336d9ecd6dac9014b1f3a102f7599a7eb28 --- /dev/null +++ b/diffsynth/core/loader/model.py @@ -0,0 +1,115 @@ +from ..vram.initialization import skip_model_initialization +from ..vram.disk_map import DiskMap +from ..vram.layers import enable_vram_management +from .file import load_state_dict +import torch +from contextlib import contextmanager +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils import ContextManagers + + +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): + config = {} if config is None else config + # Skip ZeRO-3 initialization for VAE to avoid compatibility issues + skip_zero3 = 'vae' in model_class.__name__.lower() if hasattr(model_class, '__name__') else False + with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device, skip_zero3=skip_zero3)): + model = model_class(**config) + # What is `module_map`? + # This is a module mapping table for VRAM management. + if module_map is not None: + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] + device = [d for d in devices if d != "disk"][0] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] + dtype = [d for d in dtypes if d != "disk"][0] + if vram_config["offload_device"] != "disk": + if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + if is_deepspeed_zero3_enabled(): + from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model + _load_state_dict_into_zero3_model(model, state_dict) + else: + model.load_state_dict(state_dict, assign=True) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) + else: + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) + else: + # Why do we use `DiskMap`? + # Sometimes a model file contains multiple models, + # and DiskMap can load only the parameters of a single model, + # avoiding the need to load all parameters in the file. + if state_dict is not None: + pass + elif use_disk_map: + state_dict = DiskMap(path, device, torch_dtype=torch_dtype) + else: + state_dict = load_state_dict(path, torch_dtype, device) + # Why do we use `state_dict_converter`? + # Some models are saved in complex formats, + # and we need to convert the state dict into the appropriate format. 
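+ # (A state_dict_converter is simply a callable: it receives the raw state dict and returns one whose keys match the parameter names expected by `model_class`.)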
+ if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + # Why does DeepSpeed ZeRO Stage 3 need to be handled separately? + # Because at this stage, model parameters are partitioned across multiple GPUs. + # Loading them directly could lead to excessive GPU memory consumption. + if is_deepspeed_zero3_enabled(): + from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model + _load_state_dict_into_zero3_model(model, state_dict) + else: + model.load_state_dict(state_dict, assign=True) + # Why do we call `to()`? + # Because some models override the behavior of `to()`, + # especially those from libraries like Transformers. + model = model.to(dtype=torch_dtype, device=device) + if hasattr(model, "eval"): + model = model.eval() + return model + + +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): + if isinstance(path, str): + path = [path] + config = {} if config is None else config + with skip_model_initialization(): + model = model_class(**config) + if hasattr(model, "eval"): + model = model.eval() + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch_dtype, + "computation_device": device, + } + enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) + return model + + +def get_init_context(torch_dtype, device, skip_zero3=False): + if is_deepspeed_zero3_enabled() and not skip_zero3: + from transformers.modeling_utils import set_zero3_state + import deepspeed + # Why do we use "deepspeed.zero.Init"? + # Weight segmentation of the model can be performed on the CPU side + # and loading the segmented weights onto the computing card + init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()] + elif skip_zero3: + # For models excluded from ZeRO-3 (e.g. VAE), use normal initialization + # instead of skip_model_initialization to avoid meta tensor issues + init_contexts = [] + else: + # Why do we use `skip_model_initialization`? + # It skips the random initialization of model parameters, + # thereby speeding up model loading and avoiding excessive memory usage. 
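+ # (skip_model_initialization, defined in diffsynth/core/vram/initialization.py, temporarily patches torch.nn.Module.register_parameter so new parameters are created on the meta device; real weights are attached later via load_state_dict(..., assign=True).)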
+ init_contexts = [skip_model_initialization()] + + return init_contexts diff --git a/diffsynth/core/npu_patch/npu_fused_operator.py b/diffsynth/core/npu_patch/npu_fused_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..71660414518ab90e59657813cc927a7e9368faeb --- /dev/null +++ b/diffsynth/core/npu_patch/npu_fused_operator.py @@ -0,0 +1,30 @@ +import torch +from ..device.npu_compatible_device import get_device_type +try: + import torch_npu +except: + pass + + +def rms_norm_forward_npu(self, hidden_states): + "npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py" + if hidden_states.dtype != self.weight.dtype: + hidden_states = hidden_states.to(self.weight.dtype) + return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0] + + +def rms_norm_forward_transformers_npu(self, hidden_states): + "npu rms fused operator for transformers" + if hidden_states.dtype != self.weight.dtype: + hidden_states = hidden_states.to(self.weight.dtype) + return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + + +def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor): + "npu rope fused operator for Zimage" + with torch.amp.autocast(get_device_type(), enabled=False): + freqs_cis = freqs_cis.unsqueeze(2) + cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1) + cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2) + sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2) + return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in) \ No newline at end of file diff --git a/diffsynth/core/vram/__init__.py b/diffsynth/core/vram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32763bb9b4abfa4d5b2617827661c520d7e9fcae --- /dev/null +++ b/diffsynth/core/vram/__init__.py @@ -0,0 +1,2 @@ +from .initialization import skip_model_initialization +from .layers import * diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py new file mode 100644 index 0000000000000000000000000000000000000000..a666590fa99a9cc4de05dc3f5fa84c212e43de38 --- /dev/null +++ b/diffsynth/core/vram/disk_map.py @@ -0,0 +1,93 @@ +from safetensors import safe_open +import torch, os + + +class SafetensorsCompatibleTensor: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return list(self.tensor.shape) + + +class SafetensorsCompatibleBinaryLoader: + def __init__(self, path, device): + print("Detected non-safetensors files, which may cause slower loading. 
It's recommended to convert it to a safetensors file.") + self.state_dict = torch.load(path, weights_only=True, map_location=device) + + def keys(self): + return self.state_dict.keys() + + def get_tensor(self, name): + return self.state_dict[name] + + def get_slice(self, name): + return SafetensorsCompatibleTensor(self.state_dict[name]) + + +class DiskMap: + + def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): + self.path = path if isinstance(path, list) else [path] + self.device = device + self.torch_dtype = torch_dtype + if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None: + self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')) + else: + self.buffer_size = buffer_size + self.files = [] + self.flush_files() + self.name_map = {} + for file_id, file in enumerate(self.files): + for name in file.keys(): + self.name_map[name] = file_id + self.rename_dict = self.fetch_rename_dict(state_dict_converter) + + def flush_files(self): + if len(self.files) == 0: + for path in self.path: + if path.endswith(".safetensors"): + self.files.append(safe_open(path, framework="pt", device=str(self.device))) + else: + self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device)) + else: + for i, path in enumerate(self.path): + if path.endswith(".safetensors"): + self.files[i] = safe_open(path, framework="pt", device=str(self.device)) + self.num_params = 0 + + def __getitem__(self, name): + if self.rename_dict is not None: name = self.rename_dict[name] + file_id = self.name_map[name] + param = self.files[file_id].get_tensor(name) + if self.torch_dtype is not None and isinstance(param, torch.Tensor): + param = param.to(self.torch_dtype) + if isinstance(param, torch.Tensor) and param.device == "cpu": + param = param.clone() + if isinstance(param, torch.Tensor): + self.num_params += param.numel() + if self.num_params > self.buffer_size: + self.flush_files() + return param + + def fetch_rename_dict(self, state_dict_converter): + if state_dict_converter is None: + return None + state_dict = {} + for file in self.files: + for name in file.keys(): + state_dict[name] = name + state_dict = state_dict_converter(state_dict) + return state_dict + + def __iter__(self): + if self.rename_dict is not None: + return self.rename_dict.__iter__() + else: + return self.name_map.__iter__() + + def __contains__(self, x): + if self.rename_dict is not None: + return x in self.rename_dict + else: + return x in self.name_map diff --git a/diffsynth/core/vram/initialization.py b/diffsynth/core/vram/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2498b526638bfdd1c114c78aa0b98c251a47d --- /dev/null +++ b/diffsynth/core/vram/initialization.py @@ -0,0 +1,21 @@ +import torch +from contextlib import contextmanager + + +@contextmanager +def skip_model_initialization(device=torch.device("meta")): + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + old_register_parameter = torch.nn.Module.register_parameter + torch.nn.Module.register_parameter = register_empty_parameter + try: + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter diff --git a/diffsynth/core/vram/layers.py 
b/diffsynth/core/vram/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7afb360e75fdf132be7c2e9a01f5e4b6fdd07048 --- /dev/null +++ b/diffsynth/core/vram/layers.py @@ -0,0 +1,479 @@ +import torch, copy +from typing import Union +from .initialization import skip_model_initialization +from .disk_map import DiskMap +from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE + + +class AutoTorchModule(torch.nn.Module): + + def __init__( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + super().__init__() + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.state = 0 + self.name = "" + self.computation_device_type = parse_device_type(self.computation_device) + + def set_dtype_and_device( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + self.offload_dtype = offload_dtype or computation_dtype + self.offload_device = offload_device or computation_device + self.onload_dtype = onload_dtype or computation_dtype + self.onload_device = onload_device or computation_device + self.preparing_dtype = preparing_dtype or computation_dtype + self.preparing_device = preparing_device or computation_device + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.vram_limit = vram_limit + + def cast_to(self, weight, dtype, device): + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def check_free_vram(self): + device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name() + gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device) + used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3) + return used_memory < self.vram_limit + + def offload(self): + if self.state != 0: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state != 1: + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def param_name(self, name): + if self.name == "": + return name + else: + return self.name + "." 
+ name + + +class AutoWrappedModule(AutoTorchModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.module = module + if offload_dtype == "disk": + self.name = name + self.disk_map = disk_map + self.required_params = [name for name, _ in self.module.named_parameters()] + self.disk_offload = True + else: + self.disk_offload = False + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True) + module.to(dtype=torch_dtype, device=device) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for buf in model.buffers(): + # If there are some parameters are registed in buffers (not in state dict), + # We cannot offload the model. + for children in model.children(): + self.offload_to_disk(children) + break + else: + model.to("meta") + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.offload_to_disk(self.module) + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def cast_to(self, module, dtype, device): + return copy.deepcopy(module).to(dtype=dtype, device=device) + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + module = self.module + elif self.disk_offload and device == "disk": + module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True) + else: + module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device) + return module + + def forward(self, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() + 
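+ # `computation()` below returns a module whose parameters sit on the computation device in the computation dtype; when the wrapped weights are offloaded to disk or kept in a different dtype, a temporary copy is created just for this call and discarded afterwards.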
module = self.computation() + return module(*args, **kwargs) + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedNonRecurseModule(AutoWrappedModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + module, + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + name, + disk_map, + **kwargs + ) + if self.disk_offload: + self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)] + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True, strict=False) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for name in self.required_params: + getattr(self, name).to("meta") + + def cast_to(self, module, dtype, device): + # Parameter casting is implemented in the model architecture. 
+ return module + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): + def __init__( + self, + module: torch.nn.Linear, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + with skip_model_initialization(): + super().__init__( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + ) + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.weight = module.weight + self.bias = module.bias + self.state = 0 + self.name = name + self.lora_A_weights = [] + self.lora_B_weights = [] + self.lora_merger = None + self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] + self.computation_device_type = parse_device_type(self.computation_device) + + if offload_dtype == "disk": + self.disk_map = disk_map + self.disk_offload = True + else: + self.disk_offload = False + + def fp8_linear( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ) -> torch.Tensor: + device = input.device + origin_dtype = input.dtype + origin_shape = input.shape + input = input.reshape(-1, origin_shape[-1]) + + x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values + fp8_max = 448.0 + # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn. + # To avoid overflow and ensure numerical compatibility during FP8 computation, + # we scale down the input by 2.0 in advance. + # This scaling will be compensated later during the final result scaling. 
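+ # Illustrative numbers (not taken from a real run): if a row of `input` has max |x| = 896 while fp8_max = 448, then scale_a = 2, + # the row is divided by 2 before the cast to FP8, and torch._scaled_mm re-multiplies the product by scale_a, + # so the output is returned at its original magnitude in `origin_dtype`.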
+ if self.computation_dtype == torch.float8_e4m3fnuz: + fp8_max = fp8_max / 2.0 + scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device) + scale_b = torch.ones((weight.shape[0], 1)).to(device=device) + input = input / (scale_a + 1e-8) + input = input.to(self.computation_dtype) + weight = weight.to(self.computation_dtype) + bias = bias.to(torch.bfloat16) + + result = torch._scaled_mm( + input, + weight.T, + scale_a=scale_a, + scale_b=scale_b.T, + bias=bias, + out_dtype=origin_dtype, + ) + new_shape = origin_shape[:-1] + result.shape[-1:] + result = result.reshape(new_shape) + return result + + def load_from_disk(self, torch_dtype, device, assign=True): + weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device) + bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device) + if assign: + state_dict = {"weight": weight} + if bias is not None: state_dict["bias"] = bias + self.load_state_dict(state_dict, assign=True) + return weight, bias + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.to("meta") + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.disk_offload and device == "disk": + weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False) + else: + weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device) + return weight, bias + + def linear_forward(self, x, weight, bias): + if self.enable_fp8: + out = self.fp8_linear(x, weight, bias) + else: + out = torch.nn.functional.linear(x, weight, bias) + return out + + def lora_forward(self, x, out): + if self.lora_merger is None: + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype) + else: + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + lora_output.append(x @ lora_A.T @ lora_B.T) + lora_output = torch.stack(lora_output) + out = self.lora_merger(out, lora_output) + return out + + def forward(self, x, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() + weight, bias = 
self.computation() + out = self.linear_forward(x, weight, bias) + if len(self.lora_A_weights) > 0: + out = self.lora_forward(x, out) + return out + + +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs): + if isinstance(model, AutoWrappedNonRecurseModule): + model = model.module + for name, module in model.named_children(): + layer_name = name if name_prefix == "" else name_prefix + "." + name + for source_module, target_module in module_map.items(): + if isinstance(module, source_module): + module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs) + if isinstance(module_, AutoWrappedNonRecurseModule): + enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs) + setattr(model, name, module_) + break + else: + enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs) + + +def fill_vram_config(model, vram_config): + vram_config_ = vram_config.copy() + vram_config_["onload_dtype"] = vram_config["computation_dtype"] + vram_config_["onload_device"] = vram_config["computation_device"] + vram_config_["preparing_dtype"] = vram_config["computation_dtype"] + vram_config_["preparing_device"] = vram_config["computation_device"] + for k in vram_config: + if vram_config[k] != vram_config_[k]: + print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}") + break + return vram_config_ + + +def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs): + for source_module, target_module in module_map.items(): + # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly. + if isinstance(model, source_module): + vram_config = fill_vram_config(model, vram_config) + model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs) + break + else: + enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs) + # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled. 
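+ # A typical call looks like this (illustrative values, not prescribed by this file): + #     vram_config = dict(offload_dtype=torch.bfloat16, offload_device="cpu", onload_dtype=torch.bfloat16, onload_device="cuda", preparing_dtype=torch.bfloat16, preparing_device="cuda", computation_dtype=torch.bfloat16, computation_device="cuda") + #     dit = enable_vram_management(dit, {torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule}, vram_config, vram_limit=6.0) + # after which `dit.vram_management_enabled` is True and the pipeline can offload/onload it on demand.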
+ model.vram_management_enabled = True + return model diff --git a/diffsynth/diffusion/__init__.py b/diffsynth/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4a0873a7b3d09e95aa00cfe340d653c58a834b --- /dev/null +++ b/diffsynth/diffusion/__init__.py @@ -0,0 +1,6 @@ +from .flow_match import FlowMatchScheduler +from .training_module import DiffusionTrainingModule +from .logger import ModelLogger +from .runner import launch_training_task, launch_data_process_task +from .parsers import * +from .loss import * diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..588f765a7758fc0ace4c0a495c94b30e9fb00ec4 --- /dev/null +++ b/diffsynth/diffusion/base_pipeline.py @@ -0,0 +1,500 @@ +from PIL import Image +import torch +import numpy as np +from einops import repeat, reduce +from typing import Union +from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type +from ..core.device.npu_compatible_device import get_device_type +from ..utils.lora import GeneralLoRALoader +from ..models.model_loader import ModelPool +from ..utils.controlnet import ControlNetInput +from ..core.device import get_device_name, IS_NPU_AVAILABLE + + +class PipelineUnit: + def __init__( + self, + seperate_cfg: bool = False, + take_over: bool = False, + input_params: tuple[str] = None, + output_params: tuple[str] = None, + input_params_posi: dict[str, str] = None, + input_params_nega: dict[str, str] = None, + onload_model_names: tuple[str] = None + ): + self.seperate_cfg = seperate_cfg + self.take_over = take_over + self.input_params = input_params + self.output_params = output_params + self.input_params_posi = input_params_posi + self.input_params_nega = input_params_nega + self.onload_model_names = onload_model_names + + def fetch_input_params(self): + params = [] + if self.input_params is not None: + for param in self.input_params: + params.append(param) + if self.input_params_posi is not None: + for _, param in self.input_params_posi.items(): + params.append(param) + if self.input_params_nega is not None: + for _, param in self.input_params_nega.items(): + params.append(param) + params = sorted(list(set(params))) + return params + + def fetch_output_params(self): + params = [] + if self.output_params is not None: + for param in self.output_params: + params.append(param) + return params + + def process(self, pipe, **kwargs) -> dict: + return {} + + def post_process(self, pipe, **kwargs) -> dict: + return {} + + +class BasePipeline(torch.nn.Module): + + def __init__( + self, + device=get_device_type(), torch_dtype=torch.float16, + height_division_factor=64, width_division_factor=64, + time_division_factor=None, time_division_remainder=None, + ): + super().__init__() + # The device and torch_dtype is used for the storage of intermediate variables, not models. + self.device = device + self.torch_dtype = torch_dtype + self.device_type = parse_device_type(device) + # The following parameters are used for shape check. 
+ self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # VRAM management + self.vram_management_enabled = False + # Pipeline Unit Runner + self.unit_runner = PipelineUnitRunner() + # LoRA Loader + self.lora_loader = GeneralLoRALoader + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def check_resize_height_width(self, height, width, num_frames=None, verbose=1): + # Shape check + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + if verbose > 0: + print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + if verbose > 0: + print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") + if num_frames is None: + return height, width + else: + if num_frames % self.time_division_factor != self.time_division_remainder: + num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder + if verbose > 0: + print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") + return height, width, num_frames + + + def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): + # Transform a PIL.Image to torch.Tensor + image = torch.Tensor(np.array(image, dtype=np.float32)) + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + image = image * ((max_value - min_value) / 255) + min_value + image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) + return image + + + def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a list of PIL.Image to torch.Tensor + video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] + video = torch.stack(video, dim=pattern.index("T") // 2) + return video + + + def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to PIL.Image + if pattern != "H W C": + vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") + image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) + image = image.to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(image.numpy()) + return image + + + def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to list of PIL.Image + if pattern != "T H W C": + vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") + video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] + return video + + def output_audio_format_check(self, audio_output): + # output standard foramt: [C, T], output dtype: float() + # remove batch 
dim + if audio_output.ndim == 3: + audio_output = audio_output.squeeze(0) + return audio_output.float() + + def load_models_to_device(self, model_names): + if self.vram_management_enabled: + # offload models + for name, model in self.named_children(): + if name not in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "offload"): + model.offload() + else: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + getattr(torch, self.device_type).empty_cache() + # onload models + for name, model in self.named_children(): + if name in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "onload"): + model.onload() + else: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + + + def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): + # Initialize Gaussian noise + generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) + noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + return noise + + + def get_vram(self): + device = self.device if not IS_NPU_AVAILABLE else get_device_name() + return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3) + + def get_module(self, model, name): + if "." in name: + name, suffix = name[:name.index(".")], name[name.index(".") + 1:] + if name.isdigit(): + return self.get_module(model[int(name)], suffix) + else: + return self.get_module(getattr(model, name), suffix) + else: + return getattr(model, name) + + def freeze_except(self, model_names): + self.eval() + self.requires_grad_(False) + for name in model_names: + module = self.get_module(self, name) + if module is None: + print(f"No {name} models in the pipeline. We cannot enable training on the model. 
If this occurs during the data processing stage, it is normal.") + continue + module.train() + module.requires_grad_(True) + + + def blend_with_mask(self, base, addition, mask): + return base * (1 - mask) + addition * mask + + + def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs): + timestep = scheduler.timesteps[progress_id] + if inpaint_mask is not None: + noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents) + noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask) + latents_next = scheduler.step(noise_pred, timestep, latents) + return latents_next + + + def split_pipeline_units(self, model_names: list[str]): + return PipelineUnitGraph().split_pipeline_units(self.units, model_names) + + + def flush_vram_management_device(self, device): + for module in self.modules(): + if isinstance(module, AutoTorchModule): + module.offload_device = device + module.onload_device = device + module.preparing_device = device + module.computation_device = device + + + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str] = None, + alpha=1, + hotload=None, + state_dict=None, + verbose=1, + ): + if state_dict is None: + if isinstance(lora_config, str): + lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device) + else: + lora_config.download_if_necessary() + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + else: + lora = state_dict + lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + lora = lora_loader.convert_state_dict(lora) + if hotload is None: + hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled") + if hotload: + if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")): + raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.") + updated_num = 0 + for _, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + name = module.name + lora_a_name = f'{name}.lora_A.weight' + lora_b_name = f'{name}.lora_B.weight' + if lora_a_name in lora and lora_b_name in lora: + updated_num += 1 + module.lora_A_weights.append(lora[lora_a_name] * alpha) + module.lora_B_weights.append(lora[lora_b_name]) + if verbose >= 1: + print(f"{updated_num} tensors are patched by LoRA. 
You can use `pipe.clear_lora()` to clear all LoRA layers.") + else: + lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha) + + + def clear_lora(self, verbose=1): + cleared_num = 0 + for name, module in self.named_modules(): + if isinstance(module, AutoWrappedLinear): + if hasattr(module, "lora_A_weights"): + if len(module.lora_A_weights) > 0: + cleared_num += 1 + module.lora_A_weights.clear() + if hasattr(module, "lora_B_weights"): + module.lora_B_weights.clear() + if verbose >= 1: + print(f"{cleared_num} LoRA layers are cleared.") + + + def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None): + model_pool = ModelPool() + for model_config in model_configs: + model_config.download_if_necessary() + vram_config = model_config.vram_config() + vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype + vram_config["computation_device"] = vram_config["computation_device"] or self.device + model_pool.auto_load_model( + model_config.path, + vram_config=vram_config, + vram_limit=vram_limit, + clear_parameters=model_config.clear_parameters, + state_dict=model_config.state_dict, + ) + return model_pool + + + def check_vram_management_state(self): + vram_management_enabled = False + for module in self.children(): + if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"): + vram_management_enabled = True + return vram_management_enabled + + + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + if isinstance(noise_pred_posi, tuple): + # Separately handling different output types of latents, eg. video and audio latents. + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs): + """ + compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline. + If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model. + See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments. + Args: + mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default". + dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended). + fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended). + compile_models: The list of model names to be compiled. 
If None, we will compile the models in `pipeline.compilable_models`. Default to None. + **kwargs: Other arguments for `torch.compile`. + """ + compile_models = compile_models or getattr(self, "compilable_models", []) + if len(compile_models) == 0: + print("No compilable models in the pipeline. Skip compilation.") + return + for name in compile_models: + model = getattr(self, name, None) + if model is None: + print(f"Model '{name}' not found in the pipeline.") + continue + repeated_blocks = getattr(model, "_repeated_blocks", None) + # regional compilation for repeated blocks. + if repeated_blocks is not None: + for submod in model.modules(): + if submod.__class__.__name__ in repeated_blocks: + submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs) + # compile the whole model. + else: + model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs) + print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.") + + +class PipelineUnitGraph: + def __init__(self): + pass + + def build_edges(self, units: list[PipelineUnit]): + # Establish dependencies between units + # to search for subsequent related computation units. + last_compute_unit_id = {} + edges = [] + for unit_id, unit in enumerate(units): + for input_param in unit.fetch_input_params(): + if input_param in last_compute_unit_id: + edges.append((last_compute_unit_id[input_param], unit_id)) + for output_param in unit.fetch_output_params(): + last_compute_unit_id[output_param] = unit_id + return edges + + def build_chains(self, units: list[PipelineUnit]): + # Establish updating chains for each variable + # to track their computation process. + params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], []) + params = sorted(list(set(params))) + chains = {param: [] for param in params} + for unit_id, unit in enumerate(units): + for param in unit.fetch_output_params(): + chains[param].append(unit_id) + return chains + + def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]): + # Search for units that directly participate in the model's computation. + related_unit_ids = [] + for unit_id, unit in enumerate(units): + for model_name in model_names: + if unit.onload_model_names is not None and model_name in unit.onload_model_names: + related_unit_ids.append(unit_id) + break + return related_unit_ids + + def search_related_unit_ids(self, edges, start_unit_ids, direction="target"): + # Search for subsequent related computation units. + related_unit_ids = [unit_id for unit_id in start_unit_ids] + while True: + neighbors = [] + for source, target in edges: + if direction == "target" and source in related_unit_ids and target not in related_unit_ids: + neighbors.append(target) + elif direction == "source" and source not in related_unit_ids and target in related_unit_ids: + neighbors.append(source) + neighbors = sorted(list(set(neighbors))) + if len(neighbors) == 0: + break + else: + related_unit_ids.extend(neighbors) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids): + # If the input parameters of this subgraph are updated outside the subgraph, + # search for the units where these updates occur. 
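+ # Roughly: if a parameter the subgraph consumes is later rewritten by a unit outside the subgraph, that unit is folded in as well, so the split remains self-consistent.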
+ first_compute_unit_id = {} + for unit_id in related_unit_ids: + for param in units[unit_id].fetch_input_params(): + if param not in first_compute_unit_id: + first_compute_unit_id[param] = unit_id + updating_unit_ids = [] + for param in first_compute_unit_id: + unit_id = first_compute_unit_id[param] + chain = chains[param] + if unit_id in chain and chain.index(unit_id) != len(chain) - 1: + for unit_id_ in chain[chain.index(unit_id) + 1:]: + if unit_id_ not in related_unit_ids: + updating_unit_ids.append(unit_id_) + related_unit_ids.extend(updating_unit_ids) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]): + # Split the computation graph, + # separating all model-related computations. + related_unit_ids = self.search_direct_unit_ids(units, model_names) + edges = self.build_edges(units) + chains = self.build_chains(units) + while True: + num_related_unit_ids = len(related_unit_ids) + related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target") + related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids) + if len(related_unit_ids) == num_related_unit_ids: + break + else: + num_related_unit_ids = len(related_unit_ids) + related_units = [units[i] for i in related_unit_ids] + unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids] + return related_units, unrelated_units + + +class PipelineUnitRunner: + def __init__(self): + pass + + def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: + # Let the pipeline unit take over this function. + inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) + elif unit.seperate_cfg: + # Positive side + processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_posi.update(processor_outputs) + # Negative side + if inputs_shared["cfg_scale"] != 1: + processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_nega.update(processor_outputs) + else: + inputs_nega.update(processor_outputs) + else: + processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_shared.update(processor_outputs) + return inputs_shared, inputs_posi, inputs_nega diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py new file mode 100644 index 0000000000000000000000000000000000000000..208fb1e0c3e24999f4e8e23f7c926358762a8ebc --- /dev/null +++ b/diffsynth/diffusion/flow_match.py @@ -0,0 +1,236 @@ +import torch, math +from typing_extensions import Literal + + +class FlowMatchScheduler(): + + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"): + self.set_timesteps_fn = { + "FLUX.1": FlowMatchScheduler.set_timesteps_flux, + "Wan": 
FlowMatchScheduler.set_timesteps_wan, + "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, + "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, + "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + "LTX-2": FlowMatchScheduler.set_timesteps_ltx2, + "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, + }.get(template, FlowMatchScheduler.set_timesteps_flux) + self.num_train_timesteps = 1000 + + @staticmethod + def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.003/1.002 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 5 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + @staticmethod + def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + shift_terminal = 0.02 + # Sigmas + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1 - sigmas + scale_factor = one_minus_z[-1] / (1 - shift_terminal) + sigmas = 1 - (one_minus_z / scale_factor) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + base_shift = math.log(3) + max_shift = math.log(3) + # Sigmas + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def compute_empirical_mu(image_seq_len, num_steps): + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + 
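+ # a1/b1 and a2/b2 appear to be linear fits of mu against sequence length for roughly 10-step and 200-step schedules; below, mu is interpolated in num_steps between the two fits.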
if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + @staticmethod + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None): + sigma_min = 1 / num_inference_steps + sigma_max = 1.0 + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + if dynamic_shift_len is None: + # If you ask me why I set mu=0.8, + # I can only say that it yields better training results. + mu = 0.8 + else: + mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + if target_timesteps is not None: + target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device) + for timestep in target_timesteps: + timestep_id = torch.argmin((timesteps - timestep).abs()) + timesteps[timestep_id] = timestep + return sigmas, timesteps + + @staticmethod + def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None): + num_train_timesteps = 1000 + if special_case == "stage2": + sigmas = torch.Tensor([0.909375, 0.725, 0.421875]) + elif special_case == "ditilled_stage1": + sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]) + else: + dynamic_shift_len = dynamic_shift_len or 4096 + sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( + image_seq_len=dynamic_shift_len, + base_seq_len=1024, + max_seq_len=4096, + base_shift=0.95, + max_shift=2.05, + ) + sigma_min = 0.0 + sigma_max = 1.0 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1 - terminal) + sigmas = 1.0 - (one_minus_z / scale_factor) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + def set_training_weight(self): + steps = 1000 + x = self.timesteps + y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (steps / y_shifted.sum()) + if len(self.timesteps) != 1000: + # This is an empirical formula. 
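+ # (The Gaussian bump above up-weights mid-range timesteps; the rescaling below keeps the average weight comparable when fewer than 1000 timesteps are used, and adding bsmntw_weighing[1] keeps the endpoint weights non-zero.)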
+ bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) + bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] + self.linear_timesteps_weights = bsmntw_weighing + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): + self.sigmas, self.timesteps = self.set_timesteps_fn( + num_inference_steps=num_inference_steps, + denoising_strength=denoising_strength, + **kwargs, + ) + if training: + self.set_training_weight() + self.training = True + else: + self.training = False + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise(self, original_samples, noise, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + weights = self.linear_timesteps_weights[timestep_id] + return weights diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a884057debe817526c1cb169ded9682d5add5933 --- /dev/null +++ b/diffsynth/diffusion/logger.py @@ -0,0 +1,43 @@ +import os, torch +from accelerate import Accelerator + + +class ModelLogger: + def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x, resume_step=0): + self.output_path = output_path + self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + self.state_dict_converter = state_dict_converter + self.num_steps = resume_step + + + def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): + accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, 
f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): + accelerator.wait_for_everyone() + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) + accelerator.save(state_dict, path, safe_serialization=True) diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..10ad3a0d2f06ac34e2a78f57d391db5d19c63175 --- /dev/null +++ b/diffsynth/diffusion/loss.py @@ -0,0 +1,158 @@ +from .base_pipeline import BasePipeline +import torch + + +def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) + + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) + + noise = torch.randn_like(inputs["input_latents"]) + inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + + if "first_frame_latents" in inputs: + inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"] + + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + + if "first_frame_latents" in inputs: + noise_pred = noise_pred[:, :, 1:] + training_target = training_target[:, :, 1:] + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * pipe.scheduler.training_weight(timestep) + return loss + + +def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) + + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) + + # video + noise = torch.randn_like(inputs["input_latents"]) + inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + + # audio + if inputs.get("audio_input_latents") is not None: + audio_noise = torch.randn_like(inputs["audio_input_latents"]) + inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep) + training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep) + + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * pipe.scheduler.training_weight(timestep) + if inputs.get("audio_input_latents") is not None: + loss_audio = 
torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float()) + loss_audio = loss_audio * pipe.scheduler.training_weight(timestep) + loss = loss + loss_audio + return loss + + +def DirectDistillLoss(pipe: BasePipeline, **inputs): + pipe.scheduler.set_timesteps(inputs["num_inference_steps"]) + pipe.scheduler.training = True + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id) + inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs) + loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float()) + return loss + + +class TrajectoryImitationLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.initialized = False + + def initialize(self, device): + import lpips # TODO: remove it + self.loss_fn = lpips.LPIPS(net='alex').to(device) + self.initialized = True + + def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + trajectory = [inputs_shared["latents"].clone()] + + pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared) + + trajectory.append(inputs_shared["latents"].clone()) + return pipe.scheduler.timesteps, trajectory + + def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + loss = 0 + pipe.scheduler.set_timesteps(num_inference_steps, training=True) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + + progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs()) + inputs_shared["latents"] = trajectory_teacher[progress_id_teacher] + + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + + sigma = pipe.scheduler.sigmas[progress_id] + sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1] + if progress_id + 1 >= len(pipe.scheduler.timesteps): + latents_ = trajectory_teacher[-1] + else: + progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) + latents_ = trajectory_teacher[progress_id_teacher] + + denom = sigma_ - sigma + denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) + target = (latents_ - inputs_shared["latents"]) / denom + loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) + return 
loss + + def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + inputs_shared["latents"] = trajectory_teacher[0] + pipe.scheduler.set_timesteps(num_inference_steps) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared) + + image_pred = pipe.vae_decoder(inputs_shared["latents"]) + image_real = pipe.vae_decoder(trajectory_teacher[-1]) + loss = self.loss_fn(image_pred.float(), image_real.float()) + return loss + + def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega): + if not self.initialized: + self.initialize(pipe.device) + with torch.no_grad(): + pipe.scheduler.set_timesteps(8) + timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2) + timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device) + loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss = loss_1 + loss_2 + return loss diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..78d4bf4c2d772ff0d5ff006103c56bb62eff0069 --- /dev/null +++ b/diffsynth/diffusion/parsers.py @@ -0,0 +1,71 @@ +import argparse + + +def add_dataset_base_config(parser: argparse.ArgumentParser): + parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.") + return parser + +def add_image_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + return parser + +def add_video_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. 
Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") + return parser + +def add_model_config(parser: argparse.ArgumentParser): + parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") + parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.") + parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in split training.") + return parser + +def add_training_config(parser: argparse.ArgumentParser): + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") + parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") + return parser + +def add_output_config(parser: argparse.ArgumentParser): + parser.add_argument("--output_path", type=str, default="./models", help="Output save path.") + parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Prefix to remove from parameter names in the saved checkpoint.") + parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint saving interval in steps. If None, checkpoints will be saved every epoch.") + parser.add_argument("--resume_step", type=int, default=0, help="Starting step count when resuming. ModelLogger.num_steps initializes here; training stops when num_steps reaches num_epochs * steps_per_epoch.") + return parser + +def add_lora_config(parser: argparse.ArgumentParser): + parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") + parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") + parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") + parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. 
If provided, this LoRA will be fused to the base model.") + parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.") + return parser + +def add_gradient_config(parser: argparse.ArgumentParser): + parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + return parser + +def add_general_config(parser: argparse.ArgumentParser): + parser = add_dataset_base_config(parser) + parser = add_model_config(parser) + parser = add_training_config(parser) + parser = add_output_config(parser) + parser = add_lora_config(parser) + parser = add_gradient_config(parser) + return parser diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..f51a17a146fd948976a4ece6d13f518b58c894a0 --- /dev/null +++ b/diffsynth/diffusion/runner.py @@ -0,0 +1,135 @@ +import os, torch +from tqdm import tqdm +from accelerate import Accelerator +from .training_module import DiffusionTrainingModule +from .logger import ModelLogger + + +def launch_training_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, + num_workers: int = 1, + save_steps: int = None, + num_epochs: int = 1, + args = None, +): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + model.to(device=accelerator.device) + # Exclude VAE from DeepSpeed ZeRO-3 wrapping to avoid compatibility issues + # Store VAE outside the module tree so DeepSpeed doesn't touch it + vae_module = getattr(model.pipe, 'vae', None) + if vae_module is not None: + del model.pipe._modules['vae'] + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + if vae_module is not None: + vae_module.to(accelerator.device) + # Store VAE as a non-module attribute so pipeline code can still use pipe.vae + pipe = model.module.pipe if hasattr(model, 'module') else model.pipe + # Use object.__setattr__ to bypass nn.Module's __setattr__ which would register it as a submodule + object.__setattr__(pipe, 'vae', vae_module) + initialize_deepspeed_gradient_checkpointing(accelerator) + # Training log file + log_path = os.path.join(model_logger.output_path, "training_log.txt") + if accelerator.is_main_process: + os.makedirs(model_logger.output_path, exist_ok=True) + log_file = open(log_path, "a") + log_file.write(f"Training started. 
Epochs: {num_epochs}, LR: {learning_rate}, Steps/epoch: {len(dataloader)}\n") + log_file.flush() + else: + log_file = None + + total_target = num_epochs * len(dataloader) + reached_target = False + for epoch_id in range(num_epochs): + if reached_target: + break + progress = tqdm( + total=total_target, + initial=model_logger.num_steps, + desc=f"Epoch {epoch_id+1}/{num_epochs}", + ) + for step_id, data in enumerate(dataloader): + if model_logger.num_steps >= total_target: + reached_target = True + break + with accelerator.accumulate(model): + optimizer.zero_grad() + if dataset.load_from_cache: + loss = model({}, inputs=data) + else: + loss = model(data) + accelerator.backward(loss) + optimizer.step() + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) + scheduler.step() + + # Log loss + loss_val = loss.item() + progress.update(1) + progress.set_postfix(loss=f"{loss_val:.4f}") + if accelerator.is_main_process and log_file is not None and (model_logger.num_steps % 10 == 0 or model_logger.num_steps <= 5): + log_file.write(f"epoch={epoch_id+1} step={model_logger.num_steps} loss={loss_val:.6f}\n") + log_file.flush() + progress.close() + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + if accelerator.is_main_process and log_file is not None: + log_file.write(f"Epoch {epoch_id+1} completed. Checkpoint saved.\n") + log_file.flush() + model_logger.on_training_end(accelerator, model, save_steps) + if log_file is not None: + log_file.close() + + +def launch_data_process_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + num_workers: int = 8, + args = None, +): + if args is not None: + num_workers = args.dataset_num_workers + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) + model.to(device=accelerator.device) + model, dataloader = accelerator.prepare(model, dataloader) + + for data_id, data in enumerate(tqdm(dataloader)): + with accelerator.accumulate(model): + with torch.no_grad(): + folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) + os.makedirs(folder, exist_ok=True) + save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") + data = model(data) + torch.save(data, save_path) + + +def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator): + if getattr(accelerator.state, "deepspeed_plugin", None) is not None: + ds_config = accelerator.state.deepspeed_plugin.deepspeed_config + if "activation_checkpointing" in ds_config: + import deepspeed + act_config = ds_config["activation_checkpointing"] + deepspeed.checkpointing.configure( + mpu_=None, + partition_activations=act_config.get("partition_activations", False), + checkpoint_in_cpu=act_config.get("cpu_checkpointing", False), + contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False) + ) + else: + print("Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.") diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py new file mode 100644 index 0000000000000000000000000000000000000000..0a00118049595d079eef61496c9f0a07f8e91d7d --- /dev/null +++ b/diffsynth/diffusion/training_module.py @@ -0,0 +1,302 @@ +import torch, json, os, inspect +from ..core import ModelConfig, load_state_dict +from ..utils.controlnet import ControlNetInput +from 
.base_pipeline import PipelineUnit +from peft import LoraConfig, inject_adapter_in_model + + +class GeneralUnit_RemoveCache(PipelineUnit): + def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()): + super().__init__(take_over=True) + self.required_params = required_params + self.force_remove_params_shared = force_remove_params_shared + self.force_remove_params_posi = force_remove_params_posi + self.force_remove_params_nega = force_remove_params_nega + + def process_params(self, inputs, required_params, force_remove_params): + inputs_ = {} + for name, param in inputs.items(): + if name in required_params and name not in force_remove_params: + inputs_[name] = param + return inputs_ + + def process(self, pipe, inputs_shared, inputs_posi, inputs_nega): + inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared) + inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi) + inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega) + return inputs_shared, inputs_posi, inputs_nega + + +class DiffusionTrainingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + + def to(self, *args, **kwargs): + for name, model in self.named_children(): + model.to(*args, **kwargs) + return self + + + def trainable_modules(self): + trainable_modules = filter(lambda p: p.requires_grad, self.parameters()) + return trainable_modules + + + def trainable_param_names(self): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + return trainable_param_names + + + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None): + if lora_alpha is None: + lora_alpha = lora_rank + if isinstance(target_modules, list) and len(target_modules) == 1: + target_modules = target_modules[0] + lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) + model = inject_adapter_in_model(lora_config, model) + if upcast_dtype is not None: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(upcast_dtype) + return model + + + def mapping_lora_state_dict(self, state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if "lora_A.weight" in key or "lora_B.weight" in key: + new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight") + new_state_dict[new_key] = value + elif "lora_A.default.weight" in key or "lora_B.default.weight" in key: + new_state_dict[key] = value + return new_state_dict + + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + trainable_param_names = self.trainable_param_names() + state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} + if remove_prefix is not None: + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith(remove_prefix): + name = name[len(remove_prefix):] + state_dict_[name] = param + state_dict = state_dict_ + return state_dict + + + def transfer_data_to_device(self, data, device, torch_float_dtype=None): + if data is None: + return data + elif isinstance(data, torch.Tensor): + data = data.to(device) + if torch_float_dtype is not None and data.dtype in 
[torch.float, torch.float16, torch.bfloat16]: + data = data.to(torch_float_dtype) + return data + elif isinstance(data, tuple): + data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, list): + data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, dict): + data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data} + return data + else: + return data + + def parse_vram_config(self, fp8=False, offload=False, device="cpu"): + if fp8: + return { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": device, + "onload_dtype": torch.float8_e4m3fn, + "onload_device": device, + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + } + elif offload: + return { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + "clear_parameters": True, + } + else: + return {} + + def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"): + fp8_models = [] if fp8_models is None else fp8_models.split(",") + offload_models = [] if offload_models is None else offload_models.split(",") + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + for path in model_paths: + vram_config = self.parse_vram_config( + fp8=path in fp8_models, + offload=path in offload_models, + device=device + ) + model_configs.append(ModelConfig(path=path, **vram_config)) + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + for model_id_with_origin_path in model_id_with_origin_paths: + vram_config = self.parse_vram_config( + fp8=model_id_with_origin_path in fp8_models, + offload=model_id_with_origin_path in offload_models, + device=device + ) + config = self.parse_path_or_model_id(model_id_with_origin_path) + model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config)) + return model_configs + + + def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None): + if model_id_with_origin_path is None: + return default_value + elif os.path.exists(model_id_with_origin_path): + return ModelConfig(path=model_id_with_origin_path) + else: + if ":" not in model_id_with_origin_path: + raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. 
This is neither a valid path nor in the format of `model_id/origin_file_pattern`.") + split_id = model_id_with_origin_path.rfind(":") + model_id = model_id_with_origin_path[:split_id] + origin_file_pattern = model_id_with_origin_path[split_id + 1:] + return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern) + + + def auto_detect_lora_target_modules( + self, + model: torch.nn.Module, + search_for_linear=False, + linear_detector=lambda x: min(x.weight.shape) >= 512, + block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1, + name_prefix="", + ): + lora_target_modules = [] + if search_for_linear: + for name, module in model.named_modules(): + module_name = name_prefix + ["", "."][name_prefix != ""] + name + if isinstance(module, torch.nn.Linear) and linear_detector(module): + lora_target_modules.append(module_name) + else: + for name, module in model.named_children(): + module_name = name_prefix + ["", "."][name_prefix != ""] + name + lora_target_modules += self.auto_detect_lora_target_modules( + module, + search_for_linear=block_list_detector(module), + linear_detector=linear_detector, + block_list_detector=block_list_detector, + name_prefix=module_name, + ) + return lora_target_modules + + + def parse_lora_target_modules(self, model, lora_target_modules): + if lora_target_modules == "": + print("No LoRA target modules specified. The framework will automatically search for them.") + lora_target_modules = self.auto_detect_lora_target_modules(model) + print(f"LoRA will be patched at {lora_target_modules}.") + else: + lora_target_modules = lora_target_modules.split(",") + return lora_target_modules + + + def switch_pipe_to_training_mode( + self, + pipe, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + task="sft", + ): + # Scheduler + pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Preset LoRA + if preset_lora_path is not None: + pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path) + + # FP8 + # FP8 relies on a model-specific memory management scheme. + # It is delegated to the subclass. + + # Add LoRA to the base models + if lora_base_model is not None and not task.endswith(":data_process"): + if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None: + print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.") + return + model = self.add_lora_to_model( + getattr(pipe, lora_base_model), + target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules), + lora_rank=lora_rank, + upcast_dtype=pipe.torch_dtype, + ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! 
Unexpected keys in LoRA checkpoint: {load_result[1]}") + setattr(pipe, lora_base_model, model) + + + def split_pipeline_units( + self, task, pipe, + trainable_models=None, lora_base_model=None, + # TODO: set `remove_unnecessary_params` to `True` by default + remove_unnecessary_params=False, + # TODO: move `loss_required_params` to `loss.py` + loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"), + force_remove_params_shared=tuple(), + force_remove_params_posi=tuple(), + force_remove_params_nega=tuple(), + ): + models_require_backward = [] + if trainable_models is not None: + models_require_backward += trainable_models.split(",") + if lora_base_model is not None: + models_require_backward += [lora_base_model] + if task.endswith(":data_process"): + other_units, pipe.units = pipe.split_pipeline_units(models_require_backward) + if remove_unnecessary_params: + required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters] + for unit in other_units: + required_params.extend(unit.fetch_input_params()) + required_params = sorted(list(set(required_params))) + pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega)) + elif task.endswith(":train"): + pipe.units, _ = pipe.split_pipeline_units(models_require_backward) + return pipe + + def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + controlnet_keys_map = ( + ("blockwise_controlnet_", "blockwise_controlnet_inputs",), + ("controlnet_", "controlnet_inputs"), + ) + controlnet_inputs = {} + for extra_input in extra_inputs: + for prefix, name in controlnet_keys_map: + if extra_input.startswith(prefix): + if name not in controlnet_inputs: + controlnet_inputs[name] = {} + controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input] + break + else: + inputs_shared[extra_input] = data[extra_input] + for name, params in controlnet_inputs.items(): + inputs_shared[name] = [ControlNetInput(**params)] + return inputs_shared diff --git a/diffsynth/models/anima_dit.py b/diffsynth/models/anima_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..d7519802beeabeff358f873fd3a6bd83cfc54ede --- /dev/null +++ b/diffsynth/models/anima_dit.py @@ -0,0 +1,1307 @@ +# original code from: comfy/ldm/cosmos/predict2.py + +import torch +from torch import nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple, List +import math +from torchvision import transforms +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +class VideoPositionEmb(nn.Module): + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype) + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None): + raise NotImplementedError + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. 
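+ In effect this is an RMS normalization: the scale factor `sqrt(norm.numel() / x.numel())` converts the L2 norm over `dim` into an RMS value, so the result is `x / (eps + RMS(x) over dim)`. For example, `normalize(torch.randn(2, 8, 16), dim=[-1])` rescales each length-16 vector to unit RMS.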
+ + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. + """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + device=None, + dtype=None, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype)) + self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) + self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype) + emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype) + emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype) + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, + device=None, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def 
generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + device=None, + dtype=None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device)) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device)) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) + + B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) + uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) + else: + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + + half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) + half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) + half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W), + repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W), + repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H), + ] + , dim=-2, + ) + + return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float() + + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = 
operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. + + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)") + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. 
Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." + ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
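+ # RoPE is applied after the Q/K RMSNorm, and only in self-attention, where queries and keys share the same (t, h, w) token grid.
+ # Cross-attention skips RoPE entirely, since its keys come from the context sequence rather than the video token grid.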
+ q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) + + return q, k, v + + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] + return self.output_dropout(self.output_proj(result)) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) + return self.compute_attention(q, k, v, transformer_options=transformer_options) + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: + assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.in_dim = in_features + self.out_dim = out_features + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_T_3D = emb + emb_B_T_D = sample + else: + adaln_lora_B_T_3D = None + emb_B_T_D = emb + + return emb_B_T_D, adaln_lora_B_T_3D + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. 
+ - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return 
_norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: 
Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, + ) 
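+ # Gated residual connection for the cross-attention branch: the attention output and its gate are cast
+ # back to residual_dtype (the residual stream is kept in fp32 when computing in fp16) before the add.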
+ x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1) + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + min_fps (int): Minimum frames per second. + max_fps (int): Maximum frames per second. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) + + def build_pos_embed(self, device=None, dtype=None) -> None: + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + enable_fps_modulation=self.rope_enable_fps_modulation, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, # type: ignore + ) + + if self.extra_per_block_abs_pos_emb: + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + kwargs["dtype"] = dtype + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, # type: ignore + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()): + padding_mode = "reflect" + + pad = () + for i in range(img.ndim - 2): + pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad + + return torch.nn.functional.pad(img, pad, mode=padding_mode) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, + ): + orig_shape = list(x.shape) + x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), + } + + # The residual stream for this model has large values. 
To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + + for block in self.blocks: + x_B_T_H_W_D = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_T_D=t_embedding_B_T_D, + crossattn_emb=crossattn_emb, + **block_kwargs, + ) + + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] + return x_B_C_Tt_Hp_Wp + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LLMAdapterAttention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = 
self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb2(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb2(key_states, cos, sin) + + attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class LLMAdapterTransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = LLMAdapterAttention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = LLMAdapterAttention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ 
+ LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + context = source_hidden_states + x = self.in_proj(self.embed(target_input_ids).to(context.dtype)) + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class AnimaDiT(MiniTrainDIT): + + _repeated_blocks = ["Block"] + + def __init__(self): + kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn} + super().__init__(**kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None): + if text_ids is not None: + out = self.llm_adapter(text_embeds, text_ids) + if t5xxl_weights is not None: + out = out * t5xxl_weights + + if out.shape[1] < 512: + out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])) + return out + else: + return text_embeds + + def forward( + self, + x, timesteps, context, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs + ): + t5xxl_ids = kwargs.pop("t5xxl_ids", None) + if t5xxl_ids is not None: + context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None)) + return super().forward( + x, timesteps, context, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + **kwargs + ) diff --git 
a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c394a0315306534f8209e74d6953d6061299bf23 --- /dev/null +++ b/diffsynth/models/dinov3_image_encoder.py @@ -0,0 +1,96 @@ +from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast +from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig +import torch + +from ..core.device.npu_compatible_device import get_device_type + + +class DINOv3ImageEncoder(DINOv3ViTModel): + def __init__(self): + config = DINOv3ViTConfig( + architectures = [ + "DINOv3ViTModel" + ], + attention_dropout = 0.0, + drop_path_rate = 0.0, + dtype = "float32", + hidden_act = "silu", + hidden_size = 4096, + image_size = 224, + initializer_range = 0.02, + intermediate_size = 8192, + key_bias = False, + layer_norm_eps = 1e-05, + layerscale_value = 1.0, + mlp_bias = True, + model_type = "dinov3_vit", + num_attention_heads = 32, + num_channels = 3, + num_hidden_layers = 40, + num_register_tokens = 4, + patch_size = 16, + pos_embed_jitter = None, + pos_embed_rescale = 2.0, + pos_embed_shift = None, + proj_bias = True, + query_bias = False, + rope_theta = 100.0, + transformers_version = "4.56.1", + use_gated_mlp = True, + value_bias = False + ) + super().__init__(config) + self.processor = DINOv3ViTImageProcessorFast( + crop_size = None, + data_format = "channels_first", + default_to_square = True, + device = None, + disable_grouping = None, + do_center_crop = None, + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.485, + 0.456, + 0.406 + ], + image_processor_type = "DINOv3ViTImageProcessorFast", + image_std = [ + 0.229, + 0.224, + 0.225 + ], + input_data_format = None, + resample = 2, + rescale_factor = 0.00392156862745098, + return_tensors = None, + size = { + "height": 224, + "width": 224 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) + bool_masked_pos = None + head_mask = None + + pixel_values = pixel_values.to(torch_dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=layer_head_mask, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return pooled_output diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..1eecadd6508858d09779b59b5ff8ef60744287e9 --- /dev/null +++ b/diffsynth/models/flux2_dit.py @@ -0,0 +1,1053 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch, math +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising 
Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AdaLayerNormContinuous(nn.Module): + r""" + Adaptive normalization layer with a norm layer (layer_norm or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. 
+ bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the + position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concatenated with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials.
[S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
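+
+    Example (illustrative sketch using the real-valued path, i.e. `use_real=True` with the default
+        `sequence_dim=2` layout [B, H, S, D]; the cos/sin pair comes from `get_1d_rotary_pos_embed` above):
+        >>> cos, sin = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True)  # each of shape [128, 64]
+        >>> q = torch.randn(2, 8, 128, 64)  # [B, H, S, D]
+        >>> apply_rotary_emb(q, (cos, sin)).shape
+        torch.Size([2, 8, 128, 64])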
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. 
+ """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [Flux2AttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, 
+ out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2ParallelSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + +class Flux2ParallelSelfAttention(torch.nn.Module): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. 
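+
+    Example (illustrative sketch with small dimensions; the actual Flux 2 blocks use
+        query_dim = num_attention_heads * attention_head_dim = 48 * 128 = 6144):
+        >>> attn = Flux2ParallelSelfAttention(query_dim=256, heads=4, dim_head=64, mlp_ratio=4.0, mlp_mult_factor=2)
+        >>> attn.to_qkv_mlp_proj.out_features  # 3 * 256 (fused QKV) + 2 * 1024 (gated MLP input)
+        2816
+        >>> x = torch.randn(1, 16, 256)
+        >>> attn(x).shape  # attention and MLP outputs are concatenated and projected back to query_dim
+        torch.Size([1, 16, 256])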
+ """ + + _default_processor_cls = Flux2ParallelSelfAttnProcessor + _available_processors = [Flux2ParallelSelfAttnProcessor] + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + # Fused attention output projection + MLP output projection + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. 
+ self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2ParallelSelfAttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + else: + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + 
shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) + + if guidance is not None and self.guidance_embedder is not 
None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2DiT(torch.nn.Module): + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) + # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. 
Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # 0. Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 6. 
Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output diff --git a/diffsynth/models/flux2_text_encoder.py b/diffsynth/models/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c68411f3160655ebefd49dfb6424b19373a301 --- /dev/null +++ b/diffsynth/models/flux2_text_encoder.py @@ -0,0 +1,58 @@ +from transformers import Mistral3ForConditionalGeneration, Mistral3Config + + +class Flux2TextEncoder(Mistral3ForConditionalGeneration): + def __init__(self): + config = Mistral3Config(**{ + "architectures": [ + "Mistral3ForConditionalGeneration" + ], + "dtype": "bfloat16", + "image_token_index": 10, + "model_type": "mistral3", + "multimodal_projector_bias": False, + "projector_hidden_act": "gelu", + "spatial_merge_size": 2, + "text_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + "max_position_embeddings": 131072, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 40, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000000.0, + "sliding_window": None, + "use_cache": True, + "vocab_size": 131072 + }, + "transformers_version": "4.57.1", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 1024, + "image_size": 1540, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "pixtral", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "rope_theta": 10000.0 + }, + "vision_feature_layer": -1 + }) + super().__init__(config) + + def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs): + return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs) + diff --git a/diffsynth/models/flux2_vae.py b/diffsynth/models/flux2_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c7904b17618cc3c0811e42fde0f80ecd8f15f7ee --- /dev/null +++ b/diffsynth/models/flux2_vae.py @@ -0,0 +1,2322 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math
+from functools import partial  # used by the "sde_vp" up/downsample branches in ResnetBlock2D
+from typing import Dict, Optional, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+import torch.nn.functional as F
+import inspect
+import logging
+
+# Module-level logger used by the attention and mid-block warnings below.
+logger = logging.getLogger(__name__)
+
+ACT2CLS = {
+    "swish": nn.SiLU,
+    "silu": nn.SiLU,
+    "mish": nn.Mish,
+    "gelu": nn.GELU,
+    "relu": nn.ReLU,
+}
+
+def get_activation(act_fn: str) -> nn.Module:
+    """Helper function to get activation function from string.
+
+    Args:
+        act_fn (str): Name of activation function.
+
+    Returns:
+        nn.Module: Activation function.
+    """
+
+    act_fn = act_fn.lower()
+    if act_fn in ACT2CLS:
+        return ACT2CLS[act_fn]()
+    else:
+        raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
+
+class ResnetBlock2D(nn.Module):
+    r"""
+    A Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in timestep embedding.
+        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
+        groups_out (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
+        time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config.
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
+            stronger conditioning with scale and shift.
+        kernel (`torch.Tensor`, *optional*, defaults to `None`): FIR filter, see
+            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
+        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
+        use_in_shortcut (`bool`, *optional*, defaults to `True`):
+            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
+        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
+        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
+        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
+            `conv_shortcut` output.
+        conv_2d_out_channels (`int`, *optional*, defaults to `None`): the number of channels in the output.
+            If None, same as `out_channels`.
+ """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + if time_embedding_norm == "ada_group": + raise ValueError( + "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", + ) + if time_embedding_norm == "spatial": + raise ValueError( + "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", + ) + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = 
"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) + time_scale, time_shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + time_scale) + time_shift + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor.contiguous()) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 + # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + # upsample_nearest_nhwc also fails when the number of output elements is large + # https://github.com/pytorch/pytorch/issues/141831 + scale_factor = ( + 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) + ) + if hidden_states.numel() * scale_factor > pow(2, 31): + hidden_states = hidden_states.contiguous() + + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # Cast back to original dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. 
Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. 
+ # from .normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." 
+ ) + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "layer_norm": + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # Wan applies qk norm across all heads + # Wan also doesn't apply a q norm + self.norm_added_q = None + self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_xla_flash_attention(
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
+    ) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available:
+                raise ImportError("torch_xla is not available")
+            elif is_torch_xla_version("<", "2.3"):
+                raise RuntimeError("flash attention pallas kernel is supported from torch_xla version 2.3")
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise RuntimeError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
+            else:
+                if is_flux:
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        r"""
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+ """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + is_ip_adapter = hasattr(self, "processor") and isinstance( + self.processor, + (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), + ) + is_joint_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + JointAttnProcessor2_0, + XFormersJointAttnProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xformers.ops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." 
+ ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + elif is_ip_adapter: + processor = IPAdapterXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + elif is_joint_processor: + processor = XFormersJointAttnProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_ip_adapter: + processor = IPAdapterAttnProcessor2_0( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. 
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0, output_size=attention_mask.shape[0] * head_size + ) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave( + head_size, dim=1, output_size=attention_mask.shape[1] * head_size + ) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. 
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = 
F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. 
+ + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) 
+ + return hidden_states + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock2D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + # temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # down + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. 
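+
+    Example (illustrative sketch; the channel layout mirrors the defaults passed in by `Flux2VAE`
+    below, not a documented standalone configuration):
+
+        decoder = Decoder(
+            in_channels=32,
+            out_channels=3,
+            up_block_types=("UpDecoderBlock2D",) * 4,
+            block_out_channels=(128, 256, 512, 512),
+        )
+        image = decoder(torch.randn(1, 32, 64, 64))  # -> (1, 3, 512, 512)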
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock2D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + # prev_output_channel=prev_output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Flux2VAE(torch.nn.Module): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: Tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary 
containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ): + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
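+
+        Example (illustrative; `vae` and `images` are placeholder names; the value returned is the
+        packed, BatchNorm-normalized latent tensor):
+
+            images = torch.randn(1, 3, 1024, 1024)
+            latents = vae.encode(images)  # (1, 128, 64, 64) with the default configuration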
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
+        # Pack each 2x2 spatial patch into channels, keep the first half of the doubled latent
+        # channels (the mean half of the `double_z` output), and standardize with the BatchNorm
+        # running statistics.
+        h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2)
+        h = h[:, :128]
+        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype)
+        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
+            h.device, h.dtype
+        )
+        h = (h - latents_bn_mean) / latents_bn_std
+        return h
+
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
+        dec = self.decoder(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return dec
+
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ):
+        """
+        Decode a batch of latents into images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                When `False`, the decoded tensor is returned wrapped in a plain `tuple`.
+
+        Returns:
+            `torch.Tensor` or `tuple`:
+                The decoded images. If `return_dict` is False, a plain `tuple` is returned instead.
+        """
+        # Undo the BatchNorm standardization applied in `encode`, then unpack the 2x2 patches back
+        # into the spatial dimensions before running the decoder.
+        latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype)
+        latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
+            z.device, z.dtype
+        )
+        z = z * latents_bn_std + latents_bn_mean
+        z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2)
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z)
+
+        if not return_dict:
+            return (decoded,)
+
+        return decoded
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
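+
+        Note (usage sketch): the tiled code paths read `tile_sample_min_size`, `tile_latent_min_size`
+        and `tile_overlap_factor` from the instance, and these attributes are not assigned in
+        `__init__`, so they have to be set before enabling tiling; the values below are illustrative,
+        not taken from a released configuration:
+
+            vae.use_tiling = True
+            vae.tile_sample_min_size = 512
+            vae.tile_latent_min_size = 512 // 8  # the encoder downsamples 8x before the 2x2 packing
+            vae.tile_overlap_factor = 0.25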
+ """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True): + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + return moments + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ): + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
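+
+        Example (illustrative; `vae` and `images` are placeholder names):
+
+            reconstruction = vae(images)  # equivalent to vae.decode(vae.encode(images))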
+        """
+        x = sample
+        # `encode` returns a deterministic, BatchNorm-standardized latent tensor rather than a
+        # posterior distribution, so `sample_posterior` and `generator` have no effect here.
+        z = self.encode(x)
+        dec = self.decode(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return dec
diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb1138bb74f7b55c6e92ac312098e4168829f0a
--- /dev/null
+++ b/diffsynth/models/flux_controlnet.py
@@ -0,0 +1,384 @@
+import hashlib
+import torch
+from einops import rearrange, repeat
+from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
+# from .utils import hash_state_dict_keys, init_weights_on_device
+from contextlib import contextmanager
+
+# NOTE: `convert_state_dict_keys_to_single_str` is expected to be provided by the package's shared
+# state-dict utilities; it is not defined in this file.
+def hash_state_dict_keys(state_dict, with_shape=True):
+    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+    keys_str = keys_str.encode(encoding="UTF-8")
+    return hashlib.md5(keys_str).hexdigest()
+
+@contextmanager
+def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
+    # Temporarily patch parameter/buffer registration (and optionally tensor constructors) so that
+    # modules are materialized on `device` (the meta device by default) instead of allocating real memory.
+    old_register_parameter = torch.nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = torch.nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    try:
+        torch.nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        torch.nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+class FluxControlNet(torch.nn.Module):
+    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
+        super().__init__()
+        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
+        self.time_embedder = TimestepEmbeddings(256, 3072)
+        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
+        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
+        self.context_embedder = torch.nn.Linear(4096, 3072)
+        self.x_embedder = torch.nn.Linear(64, 3072)
+
+        self.blocks = 
torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)]) + + self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)]) + self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)]) + + self.mode_dict = mode_dict + self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None + self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072) + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states): + if len(res_stack) == 0: + return [torch.zeros_like(hidden_states)] * num_blocks + interval = (num_blocks + len(res_stack) - 1) // len(res_stack) + aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)] + return aligned_res_stack + + + def forward( + self, + hidden_states, + controlnet_conditioning, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + processor_id=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) + if self.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + if self.controlnet_mode_embedder is not None: # Different from FluxDiT + processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int) + processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device) + prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1) + text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + hidden_states = self.patchify(hidden_states) + hidden_states = self.x_embedder(hidden_states) + controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT + + controlnet_res_stack = [] + for block, controlnet_block in zip(self.blocks, self.controlnet_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + 
controlnet_res_stack.append(controlnet_block(hidden_states)) + + controlnet_single_res_stack = [] + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:])) + + controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:]) + controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:]) + + return controlnet_res_stack, controlnet_single_res_stack + + + # @staticmethod + # def state_dict_converter(): + # return FluxControlNetStateDictConverter() + + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class QLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self,input) + return torch.nn.functional.linear(input,weight,bias) + + class QRMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + class QEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight= cast_weight(self,input) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module,quantized_layer.QRMSNorm): + continue + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.QLinear(module.in_features,module.out_features) + new_layer.weight = module.weight + if module.bias is not None: + new_layer.bias = module.bias + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True + new_layer = quantized_layer.QRMSNorm(module) + setattr(model, name, new_layer) + elif isinstance(module,torch.nn.Embedding): + 
rows, cols = module.weight.shape + new_layer = quantized_layer.QEmbedding( + num_embeddings=rows, + embedding_dim=cols, + _weight=module.weight, + # _freeze=module.freeze, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) + + + +class FluxControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + hash_value = hash_state_dict_keys(state_dict) + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." 
in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + if hash_value == "78d18b9101345ff695f312e7e62538c0": + extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}} + elif hash_value == "b001c89139b5f053c715fe772362dd2a": + extra_kwargs = {"num_single_blocks": 0} + elif hash_value == "52357cb26250681367488a8954c271e8": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} + elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} + elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10} + elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0} + else: + extra_kwargs = {} + return state_dict_, extra_kwargs + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..46fa861dedeaee37c7f58f477f544f802f6deffa --- /dev/null +++ b/diffsynth/models/flux_dit.py @@ -0,0 +1,398 @@ +import torch +from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm +from einops import rearrange + + +def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size, num_tokens = hidden_states.shape[0:2] + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + +class RoPEEmbedding(torch.nn.Module): + def __init__(self, dim, theta, axes_dim): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + + def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." 
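+        # Frequencies follow the usual RoPE schedule: omega_i = theta ** (-2i / dim) for i in [0, dim // 2).
+        # Each (position, frequency) pair is expanded into a 2x2 rotation matrix [[cos, -sin], [sin, cos]],
+        # giving an output of shape (batch, seq_len, dim // 2, 2, 2) that `apply_rope` applies to
+        # query/key channel pairs.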
+ + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + + +class FluxJointAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.only_out_a = only_out_a + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + self.norm_q_b = RMSNorm(head_dim, eps=1e-6) + self.norm_k_b = RMSNorm(head_dim, eps=1e-6) + + self.a_to_out = torch.nn.Linear(dim_a, dim_a) + if not only_out_a: + self.b_to_out = torch.nn.Linear(dim_b, dim_b) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states_a.shape[0] + + # Part A + qkv_a = self.a_to_qkv(hidden_states_a) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v_a = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + # Part B + qkv_b = self.b_to_qkv(hidden_states_b) + qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_b, k_b, v_b = qkv_b.chunk(3, dim=1) + q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b) + + q = torch.concat([q_b, q_a], dim=2) + k = torch.concat([k_b, k_a], dim=2) + v = torch.concat([v_b, v_a], dim=2) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] + if ipadapter_kwargs_list is not None: + hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) + hidden_states_a = self.a_to_out(hidden_states_a) + if self.only_out_a: + return hidden_states_a + else: + hidden_states_b = self.b_to_out(hidden_states_b) + return hidden_states_a, hidden_states_b + + + +class FluxJointTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.norm1_a = AdaLayerNorm(dim) + self.norm1_b = AdaLayerNorm(dim) + + self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads) + + self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_a = torch.nn.Sequential( + 
torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_b = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + + # Attention + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + + # Part A + hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a + hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) + + # Part B + hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b + norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b + hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) + + return hidden_states_a, hidden_states_b + + + +class FluxSingleAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def forward(self, hidden_states, image_rotary_emb): + batch_size = hidden_states.shape[0] + + qkv_a = self.a_to_qkv(hidden_states) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + q, k = self.apply_rope(q_a, k_a, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + return hidden_states + + + +class AdaLayerNormSingle(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, 3 * dim, bias=True) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + + +class FluxSingleTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = dim // num_attention_heads + self.dim = dim + + self.norm = AdaLayerNormSingle(dim) + self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4)) + self.norm_q_a = RMSNorm(self.head_dim, 
eps=1e-6) + self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6) + + self.proj_out = torch.nn.Linear(dim * 5, dim) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states.shape[0] + + qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q, k, v = qkv.chunk(3, dim=1) + q, k = self.norm_q_a(q), self.norm_k_a(k) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + if ipadapter_kwargs_list is not None: + hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list) + return hidden_states + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + residual = hidden_states_a + norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) + hidden_states_a = self.to_qkv_mlp(norm_hidden_states) + attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] + + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") + + hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) + hidden_states_a = residual + hidden_states_a + + return hidden_states_a, hidden_states_b + + + +class AdaLayerNormContinuous(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, dim * 2, bias=True) + self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False) + + def forward(self, x, conditioning): + emb = self.linear(self.silu(conditioning)) + shift, scale = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None] + shift[:, None] + return x + + + +class FluxDiT(torch.nn.Module): + + _repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"] + + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) + + self.final_norm_out = AdaLayerNormContinuous(3072) + self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = 
input_dim + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) + return hidden_states + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask + + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim): + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, 
None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, + use_gradient_checkpointing=False, + **kwargs + ): + # (Deprecated) The real forward is in `pipelines.flux_image`. + return None diff --git a/diffsynth/models/flux_infiniteyou.py b/diffsynth/models/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..861538a4b02fb6a52edee662b6efcd60f78f6916 --- /dev/null +++ b/diffsynth/models/flux_infiniteyou.py @@ -0,0 +1,129 @@ +import math +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class InfiniteYouImageProjector(nn.Module): + + def __init__( + self, + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=8, + embedding_dim=512, + output_dim=4096, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + 
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + latents = latents.to(dtype=x.dtype, device=x.device) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + @staticmethod + def state_dict_converter(): + return FluxInfiniteYouImageProjectorStateDictConverter() + + +class FluxInfiniteYouImageProjectorStateDictConverter: + + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict['image_proj'] diff --git a/diffsynth/models/flux_ipadapter.py b/diffsynth/models/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..31176fc2c2a508388502b45dc27e4d2218f16eec --- /dev/null +++ b/diffsynth/models/flux_ipadapter.py @@ -0,0 +1,110 @@ +from .general_modules import RMSNorm +from transformers import SiglipVisionModel, SiglipVisionConfig +import torch + + +class SiglipVisionModelSO400M(SiglipVisionModel): + def __init__(self): + config = SiglipVisionConfig( + hidden_size=1152, + image_size=384, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + patch_size=14, + architectures=["SiglipModel"], + initializer_factor=1.0, + torch_dtype="float32", + transformers_version="4.37.0.dev0" + ) + super().__init__(config) + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class IpAdapterModule(torch.nn.Module): + def __init__(self, num_attention_heads, attention_head_dim, input_dim): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + output_dim = num_attention_heads * attention_head_dim + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False) + + + def forward(self, hidden_states): + batch_size = hidden_states.shape[0] + # ip_k + ip_k = self.to_k_ip(hidden_states) + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_k = self.norm_added_k(ip_k) + # ip_v + ip_v = self.to_v_ip(hidden_states) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + return ip_k, ip_v + + +class FluxIpAdapter(torch.nn.Module): + def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57): + super().__init__() + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)]) + self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens) + self.set_adapter() + + def set_adapter(self): + 
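+        # Default routing (descriptive note): each transformer block index maps to its own IP-Adapter module, an identity mapping over all blocks.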
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))} + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for block_id in self.call_block_id: + ipadapter_id = self.call_block_id[block_id] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + ip_kv_dict[block_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + @staticmethod + def state_dict_converter(): + return FluxIpAdapterStateDictConverter() + + +class FluxIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..13589b0611f3140479ef4faa3b7a29371caa447b --- /dev/null +++ b/diffsynth/models/flux_lora_encoder.py @@ -0,0 +1,521 @@ +import torch +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = 
self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) 
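+        # After triu_(1), entries strictly above the diagonal stay -inf and all others become 0, yielding a causal attention mask.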
+ return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": 
"encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in, dim_out): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + self.layer_norm = torch.nn.LayerNorm(dim_out) + + def forward(self, lora_A, lora_B): + x = self.x @ lora_A.T @ lora_B.T + x = self.layer_norm(x) + return x + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1]) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + 
layer_type, dim = lora_pattern["type"], lora_pattern["dim"] + if layer_type not in proj_dict: + proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class FluxLoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1): + super().__init__() + self.num_embeds_per_lora = num_embeds_per_lora + # embedder + self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)]) + + # special embedding + self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim)) + self.num_special_embeds = num_special_embeds + + # final layer + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + self.final_linear = torch.nn.Linear(embed_dim, embed_dim) + + def forward(self, lora): + lora_embeds = self.embedder(lora) + special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device) + embeds = torch.concat([special_embeds, lora_embeds], dim=1) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = embeds[:, :self.num_special_embeds] + embeds = self.final_layer_norm(embeds) + embeds = self.final_linear(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return FluxLoRAEncoderStateDictConverter() + + +class FluxLoRAEncoderStateDictConverter: + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/flux_lora_patcher.py b/diffsynth/models/flux_lora_patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8fc8cea03bbbc658d8e1869432d903b6ae7ce9 --- /dev/null +++ b/diffsynth/models/flux_lora_patcher.py @@ -0,0 +1,306 @@ +import torch, math +from ..core.loader import load_state_dict +from typing import Union + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, 
lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_B." not in key: + continue + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) + return lora_name_dict + + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + updated_num = 0 + lora_name_dict = self.get_name_dict(state_dict_lora) + for name, module in model.named_modules(): + if name in lora_name_dict: + weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict = module.state_dict() + state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict) + updated_num += 1 + print(f"{updated_num} tensors are updated by LoRA.") + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight", + 
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": 
"blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().load(model, state_dict_lora, alpha) + + + def convert_state_dict(self,state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else 
self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class FluxLoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return 
self.model_dict[name.replace(".", "___")](base_output, lora_outputs) diff --git a/diffsynth/models/flux_text_encoder_clip.py b/diffsynth/models/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1425423ce6d1df946198a16a7e96078ab8fed807 --- /dev/null +++ b/diffsynth/models/flux_text_encoder_clip.py @@ -0,0 +1,112 @@ +import torch + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FluxTextEncoderClip(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = 
torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=2, extra_mask=None): + embeds = self.token_embedding(input_ids) + embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device) + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + if extra_mask is not None: + attn_mask[:, extra_mask[0]==0] = float("-inf") + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + hidden_states = embeds + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + return pooled_embeds, hidden_states diff --git a/diffsynth/models/flux_text_encoder_t5.py b/diffsynth/models/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..ee72e4a89b2089b62c6ea86c4aef91755ee9ee9a --- /dev/null +++ b/diffsynth/models/flux_text_encoder_t5.py @@ -0,0 +1,43 @@ +import torch +from transformers import T5EncoderModel, T5Config + + +class FluxTextEncoderT5(T5EncoderModel): + def __init__(self): + config = T5Config(**{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "dtype": "bfloat16", + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": True, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": False, + "transformers_version": "4.57.1", + "use_cache": True, + "vocab_size": 32128 + }) + super().__init__(config) + + def forward(self, input_ids): + outputs = super().forward(input_ids=input_ids) + prompt_emb = outputs.last_hidden_state + return prompt_emb diff --git a/diffsynth/models/flux_vae.py b/diffsynth/models/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5eabeaee6ad54f0b1c02f1b24cd2ccd84d0238bf --- /dev/null +++ b/diffsynth/models/flux_vae.py @@ -0,0 +1,451 @@ +import torch +from einops import rearrange, repeat + + +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 
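+ # Illustrative example: with height = width = 4 and border_width = 2, the distance-to-border
+ # values are [[1,1,1,1],[1,2,2,1],[1,2,2,1],[1,1,1,1]]; dividing by border_width and clipping
+ # to [0, 1] gives a weight of 0.5 on the border ring and 1.0 in the centre, which is what
+ # blends overlapping tiles back together in untile().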
+ x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happened in forward_fn + # We only consider the same scale on height and width. + io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # 
Done! + model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output + + +class ConvAttention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q) + self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + q = self.to_q(conv_input) + q = rearrange(q[:, :, :, 0], "B C L -> B L C") + conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1") + k = self.to_k(conv_input) + v = self.to_v(conv_input) + k = rearrange(k[:, :, :, 0], "B C L -> B L C") + v = rearrange(v[:, :, :, 0], "B C L -> B L C") + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + hidden_states = self.to_out(conv_input) + hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C") + + return hidden_states + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, 
in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + if use_conv_attention: + self.transformer_blocks = torch.nn.ModuleList([ + ConvAttention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + else: + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class ResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = hidden_states + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb)[:, :, None, None] + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +class UpSampler(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class DownSampler(torch.nn.Module): + def __init__(self, channels, padding=1, extra_padding=False): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) + self.extra_padding = extra_padding + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + if self.extra_padding: + hidden_states = 
torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class FluxVAEDecoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock2D + ResnetBlock(256, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = sample / self.scaling_factor + self.shift_factor + hidden_states = self.conv_in(hidden_states) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class FluxVAEEncoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownEncoderBlock2D + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + DownSampler(128, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(128, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + DownSampler(256, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(256, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + DownSampler(512, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = self.conv_in(sample) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = hidden_states[:, :16] + hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor + + return hidden_states + + def encode_video(self, sample, batch_size=8): + B = sample.shape[0] + hidden_states = [] + + for i in range(0, sample.shape[2], batch_size): + + j = min(i + batch_size, sample.shape[2]) + sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W") + + hidden_states_batch = self(sample_batch) + hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B) + + hidden_states.append(hidden_states_batch) + + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py new file mode 100644 index 0000000000000000000000000000000000000000..549dbc93b41343a42266af11584e2e7d39a17cd6 --- /dev/null +++ b/diffsynth/models/flux_value_control.py @@ -0,0 +1,56 @@ +import torch +from .general_modules import TemporalTimesteps + + +class MultiValueEncoder(torch.nn.Module): + def __init__(self, encoders=()): + super().__init__() + if not isinstance(encoders, list): + encoders = [encoders] + self.encoders = torch.nn.ModuleList(encoders) + + def __call__(self, values, dtype): + emb = [] + for encoder, value in zip(self.encoders, values): + if value is not None: + value = value.unsqueeze(0) + emb.append(encoder(value, dtype)) + emb = torch.concat(emb, dim=0) + return emb + + +class SingleValueEncoder(torch.nn.Module): + def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None): + super().__init__() + self.prefer_len = prefer_len + self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device) + self.prefer_value_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.positional_embedding = torch.nn.Parameter( + torch.randn(self.prefer_len, dim_out) + ) + + def forward(self, value, dtype): + value = value * 1000 + emb = self.prefer_proj(value).to(dtype) + emb = self.prefer_value_embedder(emb).squeeze(0) + base_embeddings = emb.expand(self.prefer_len, -1) + positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device) + learned_embeddings = base_embeddings + positional_embedding + return learned_embeddings + + @staticmethod + def state_dict_converter(): + return SingleValueEncoderStateDictConverter() + + +class SingleValueEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/general_modules.py b/diffsynth/models/general_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1e97ba6e80447ab9f10eed9424edbd6e3d147cb4 --- /dev/null +++ b/diffsynth/models/general_modules.py @@ -0,0 +1,146 @@ +import torch, math + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, + computation_device = None, + align_dtype_to_timestep = False, +): + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = 
-math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + if align_dtype_to_timestep: + emb = emb.to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.computation_device = computation_device + self.scale = scale + self.align_dtype_to_timestep = align_dtype_to_timestep + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + computation_device=self.computation_device, + scale=self.scale, + align_dtype_to_timestep=self.align_dtype_to_timestep, + ) + return t_emb + + +class DiffusersCompatibleTimestepProj(torch.nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.linear_1 = torch.nn.Linear(dim_in, dim_out) + self.act = torch.nn.SiLU() + self.linear_2 = torch.nn.Linear(dim_out, dim_out) + + def forward(self, x): + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TimestepEmbeddings(torch.nn.Module): + def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False): + super().__init__() + self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep) + if diffusers_compatible_format: + self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out) + else: + self.timestep_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = torch.nn.Embedding(2, dim_out) + + def forward(self, timestep, dtype, addition_t_cond=None): + time_emb = self.time_proj(timestep).to(dtype) + time_emb = self.timestep_embedder(time_emb) + if addition_t_cond is not None: + addition_t_emb = self.addition_t_embedding(addition_t_cond) + addition_t_emb = addition_t_emb.to(dtype=dtype) + time_emb = time_emb + addition_t_emb + return time_emb + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps, elementwise_affine=True): + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.ones((dim,))) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * 
torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(input_dtype) + if self.weight is not None: + hidden_states = hidden_states * self.weight + return hidden_states + + +class AdaLayerNorm(torch.nn.Module): + def __init__(self, dim, single=False, dual=False): + super().__init__() + self.single = single + self.dual = dual + self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual]) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(torch.nn.functional.silu(emb)) + if self.single: + scale, shift = emb.unsqueeze(1).chunk(2, dim=2) + x = self.norm(x) * (1 + scale) + shift + return x + elif self.dual: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2) + norm_x = self.norm(x) + x = norm_x * (1 + scale_msa) + shift_msa + norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2 + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2 + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe1c21b9304c484c72cb17e66ff97e7c85bc16b --- /dev/null +++ b/diffsynth/models/longcat_video_dit.py @@ -0,0 +1,902 @@ +from typing import List, Optional, Tuple + +import math +import torch +import torch.nn as nn +import torch.amp as amp + +import numpy as np +import torch.nn.functional as F +from einops import rearrange, repeat +from .wan_video_dit import flash_attention +from ..core.device.npu_compatible_device import get_device_type +from ..core.gradient import gradient_checkpoint_forward + + +class RMSNorm_FP32(torch.nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +class RotaryPositionalEmbedding(nn.Module): + + def __init__(self, + head_dim, + cp_split_hw=None + ): + """Rotary positional embedding for 3D + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + dim: Dimension of embedding + base: Base value for exponential + """ + super().__init__() + self.head_dim = head_dim + assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.' + self.cp_split_hw = cp_split_hw + # We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels + self.base = 10000 + self.freqs_dict = {} + + def register_grid_size(self, grid_size): + if grid_size not in self.freqs_dict: + self.freqs_dict.update({ + grid_size: self.precompute_freqs_cis_3d(grid_size) + }) + + def precompute_freqs_cis_3d(self, grid_size): + num_frames, height, width = grid_size + dim_t = self.head_dim - 4 * (self.head_dim // 6) + dim_h = 2 * (self.head_dim // 6) + dim_w = 2 * (self.head_dim // 6) + freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32) + grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32) + grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32) + grid_t = torch.from_numpy(grid_t).float() + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + # (T H W D) + freqs = rearrange(freqs, "T H W D -> (T H W) D") + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # with torch.no_grad(): + # freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width) + # freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw) + # freqs = rearrange(freqs, "T H W D -> (T H W) D") + + return freqs + + def forward(self, q, k, grid_size): + """3D RoPE. + + Args: + query: [B, head, seq, head_dim] + key: [B, head, seq, head_dim] + Returns: + query and key with the same shape as input. 
+ """ + + if grid_size not in self.freqs_dict: + self.register_grid_size(grid_size) + + freqs_cis = self.freqs_dict[grid_size].to(q.device) + q_, k_ = q.float(), k.float() + freqs_cis = freqs_cis.float().to(q.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + q_ = (q_ * cos) + (rotate_half(q_) * sin) + k_ = (k_ * cos) + (rotate_half(k_) * sin) + + return q_.type_as(q), k_.type_as(k) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = None, + cp_split_hw: Optional[List[int]] = None + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + self.enable_bsa = enable_bsa + self.bsa_params = bsa_params + self.cp_split_hw = cp_split_hw + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.proj = nn.Linear(dim, dim) + + self.rope_3d = RotaryPositionalEmbedding( + self.head_dim, + cp_split_hw=cp_split_hw + ) + + def _process_attn(self, q, k, v, shape): + q = rearrange(q, "B H S D -> B S (H D)") + k = rearrange(k, "B H S D -> B S (H D)") + v = rearrange(v, "B H S D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads) + return x + + def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if return_kv: + k_cache, v_cache = k.clone(), v.clone() + + q, k = self.rope_3d(q, k, shape) + + # cond mode + if num_cond_latents is not None and num_cond_latents > 0: + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + # process the condition tokens + q_cond = q[:, :, :num_cond_latents_thw].contiguous() + k_cond = k[:, :, :num_cond_latents_thw].contiguous() + v_cond = v[:, :, :num_cond_latents_thw].contiguous() + x_cond = self._process_attn(q_cond, k_cond, v_cond, shape) + # process the noise tokens + q_noise = q[:, :, num_cond_latents_thw:].contiguous() + x_noise = self._process_attn(q_noise, k, v, shape) + # merge x_cond and x_noise + x = torch.cat([x_cond, x_noise], dim=2).contiguous() + else: + x = self._process_attn(q, k, v, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + if return_kv: + return x, (k_cache, v_cache) + else: + return x + + def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + T, H, W = shape + k_cache, v_cache = 
kv_cache + assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B] + if k_cache.shape[0] == 1: + k_cache = k_cache.repeat(B, 1, 1, 1) + v_cache = v_cache.repeat(B, 1, 1, 1) + + if num_cond_latents is not None and num_cond_latents > 0: + k_full = torch.cat([k_cache, k], dim=2).contiguous() + v_full = torch.cat([v_cache, v], dim=2).contiguous() + q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous() + q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W)) + q = q_padding[:, :, -N:].contiguous() + + x = self._process_attn(q, k_full, v_full, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + enable_flashattn3=False, + enable_flashattn2=False, + enable_xformers=False, + ): + super(MultiHeadCrossAttention, self).__init__() + assert dim % num_heads == 0, "d_model must be divisible by num_heads" + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = nn.Linear(dim, dim) + self.kv_linear = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + + def _process_cross_attn(self, x, cond, kv_seqlen): + B, N, C = x.shape + assert C == self.dim and cond.shape[2] == self.dim + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + q, k = self.q_norm(q), self.k_norm(k) + + q = rearrange(q, "B S H D -> B S (H D)") + k = rearrange(k, "B S H D -> B S (H D)") + v = rearrange(v, "B S H D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + + x = x.view(B, -1, C) + x = self.proj(x) + return x + + def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None): + """ + x: [B, N, C] + cond: [B, M, C] + """ + if num_cond_latents is None or num_cond_latents == 0: + return self._process_cross_attn(x, cond, kv_seqlen) + else: + B, N, C = x.shape + if num_cond_latents is not None and num_cond_latents > 0: + assert shape is not None, "SHOULD pass in the shape" + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C] + output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C] + output = torch.cat([ + torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device), + output_noise + ], dim=1).contiguous() + else: + raise NotImplementedError + + return output + + +class LayerNorm_FP32(nn.LayerNorm): + def __init__(self, dim, eps, elementwise_affine): + super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + out = F.layer_norm( + inputs.float(), + self.normalized_shape, + None if self.weight is None else self.weight.float(), + None if self.bias is None else self.bias.float() , + self.eps + ).to(origin_dtype) + return out + + +def modulate_fp32(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D) + # ensure the modulation params be fp32 + 
assert shift.dtype == torch.float32, scale.dtype == torch.float32 + dtype = x.dtype + x = norm_func(x.to(torch.float32)) + x = x * (scale + 1) + shift + x = x.to(dtype) + return x + + +class FinalLayer_FP32(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim): + super().__init__() + self.hidden_size = hidden_size + self.num_patch = num_patch + self.out_channels = out_channels + self.adaln_tembed_dim = adaln_tembed_dim + + self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True)) + + def forward(self, x, t, latent_shape): + # timestep shape: [B, T, C] + assert t.dtype == torch.float32 + B, N, C = x.shape + T, _, _ = latent_shape + + with amp.autocast(get_device_type(), dtype=torch.float32): + shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] + x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) + x = self.linear(x) + return x + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.dim = dim + self.hidden_dim = hidden_dim + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, t_embed_dim, frequency_embedding_size=256): + super().__init__() + self.t_embed_dim = t_embed_dim + self.frequency_embedding_size = frequency_embedding_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, t_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(t_embed_dim, t_embed_dim, bias=True), + ) + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. 
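+ Here the "labels" are caption features: the forward pass projects text-encoder
+ outputs of shape [B, 1, N_token, in_channels] to hidden_size through a two-layer
+ MLP with tanh-approximated GELU.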
+ """ + + def __init__(self, in_channels, hidden_size): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.y_proj = nn.Sequential( + nn.Linear(in_channels, hidden_size, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, caption): + B, _, N, C = caption.shape + caption = self.y_proj(caption) + return caption + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + B, C, T, H, W = x.shape + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class LongCatSingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: int, + adaln_tembed_dim: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params=None, + cp_split_hw=None + ): + super().__init__() + + self.hidden_size = hidden_size + + # scale and gate modulation + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True) + ) + + self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True) + + self.attn = Attention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + self.cross_attn = MultiHeadCrossAttention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + ) + self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio)) + + def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False): + """ + x: [B, N, C] + y: [1, N_valid_tokens, C] + t: [B, T, C_t] + y_seqlen: [B]; type of a list + latent_shape: latent shape of 
a single item + """ + x_dtype = x.dtype + + B, N, C = x.shape + T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. + + # compute modulation params in fp32 + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + shift_msa, scale_msa, gate_msa, \ + shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] + + # self attn with modulation + x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C) + + if kv_cache is not None: + kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device)) + attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache) + else: + attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv) + + if return_kv: + x_s, kv_cache = attn_outputs + else: + x_s = attn_outputs + + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + # cross attn + if not skip_crs_attn: + if kv_cache is not None: + num_cond_latents = None + x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape) + + # ffn with modulation + x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) + x_s = self.ffn(x_m) + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + if return_kv: + return x, kv_cache + else: + return x + + +class LongCatVideoTransformer3DModel(torch.nn.Module): + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + hidden_size: int = 4096, + depth: int = 48, + num_heads: int = 32, + caption_channels: int = 4096, + mlp_ratio: int = 4, + adaln_tembed_dim: int = 512, + frequency_embedding_size: int = 256, + # default params + patch_size: Tuple[int] = (1, 2, 2), + # attention config + enable_flashattn3: bool = False, + enable_flashattn2: bool = True, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}, + cp_split_hw: Optional[List[int]] = [1, 1], + text_tokens_zero_pad: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.cp_split_hw = cp_split_hw + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + ) + + self.blocks = nn.ModuleList( + [ + LongCatSingleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + adaln_tembed_dim=adaln_tembed_dim, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer_FP32( + hidden_size, + np.prod(self.patch_size), + out_channels, + adaln_tembed_dim, + ) + + self.gradient_checkpointing = False + self.text_tokens_zero_pad = text_tokens_zero_pad + + self.lora_dict = {} + self.active_loras = [] + + def 
enable_loras(self, lora_key_list=[]): + self.disable_all_loras() + + module_loras = {} # {module_name: [lora1, lora2, ...]} + model_device = next(self.parameters()).device + model_dtype = next(self.parameters()).dtype + + for lora_key in lora_key_list: + if lora_key in self.lora_dict: + for lora in self.lora_dict[lora_key].loras: + lora.to(model_device, dtype=model_dtype, non_blocking=True) + module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".") + if module_name not in module_loras: + module_loras[module_name] = [] + module_loras[module_name].append(lora) + self.active_loras.append(lora_key) + + for module_name, loras in module_loras.items(): + module = self._get_module_by_name(module_name) + if not hasattr(module, 'org_forward'): + module.org_forward = module.forward + module.forward = self._create_multi_lora_forward(module, loras) + + def _create_multi_lora_forward(self, module, loras): + def multi_lora_forward(x, *args, **kwargs): + weight_dtype = x.dtype + org_output = module.org_forward(x, *args, **kwargs) + + total_lora_output = 0 + for lora in loras: + if lora.use_lora: + lx = lora.lora_down(x.to(lora.lora_down.weight.dtype)) + lx = lora.lora_up(lx) + lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale + total_lora_output += lora_output + + return org_output + total_lora_output + + return multi_lora_forward + + def _get_module_by_name(self, module_name): + try: + module = self + for part in module_name.split('.'): + module = getattr(module, part) + return module + except AttributeError as e: + raise ValueError(f"Cannot find module: {module_name}, error: {e}") + + def disable_all_loras(self): + for name, module in self.named_modules(): + if hasattr(module, 'org_forward'): + module.forward = module.org_forward + delattr(module, 'org_forward') + + for lora_key, lora_network in self.lora_dict.items(): + for lora in lora_network.loras: + lora.to("cpu") + + self.active_loras.clear() + + def enable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = True + + def disable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = False + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask=None, + num_cond_latents=0, + return_kv=False, + kv_cache_dict={}, + skip_crs_attn=False, + offload_kv_cache=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + + B, _, T, H, W = hidden_states.shape + + N_t = T // self.patch_size[0] + N_h = H // self.patch_size[1] + N_w = W // self.patch_size[2] + + assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension." 
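+ # Shape sketch (illustrative): with the default patch_size (1, 2, 2), a latent of shape
+ # [B, in_channels, T, H, W] becomes N = T * (H // 2) * (W // 2) tokens of width hidden_size,
+ # and the per-frame timestep embedding t computed below has shape [B, T, adaln_tembed_dim].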
+ + # expand the shape of timestep from [B] to [B, T] + if len(timestep.shape) == 1: + timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T] + timestep[:, :num_cond_latents] = 0 + + dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype) + timestep = timestep.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + + hidden_states = self.x_embedder(hidden_states) # [B, N, C] + + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] + + encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] + + if self.text_tokens_zero_pad and encoder_attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None] + encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype) + + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1) + encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C] + y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B] + else: + y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1]) + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w) + # hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw) + # hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C") + + # blocks + kv_cache_dict_ret = {} + for i, block in enumerate(self.blocks): + block_outputs = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=hidden_states, + y=encoder_hidden_states, + t=t, + y_seqlen=y_seqlens, + latent_shape=(N_t, N_h, N_w), + num_cond_latents=num_cond_latents, + return_kv=return_kv, + kv_cache=kv_cache_dict.get(i, None), + skip_crs_attn=skip_crs_attn, + ) + + if return_kv: + hidden_states, kv_cache = block_outputs + if offload_kv_cache: + kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu()) + else: + kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous()) + else: + hidden_states = block_outputs + + hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out] + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw) + + hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W] + + # cast to float32 for better accuracy + hidden_states = hidden_states.to(torch.float32) + + if return_kv: + return hidden_states, kv_cache_dict_ret + else: + return hidden_states + + + def unpatchify(self, x, N_t, N_h, N_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + 
return x + + @staticmethod + def state_dict_converter(): + return LongCatVideoTransformer3DModelDictConverter() + + +class LongCatVideoTransformer3DModelDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict + diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..082a7b8eeb7d998f929096fa74328939ade008dd --- /dev/null +++ b/diffsynth/models/ltx2_audio_vae.py @@ -0,0 +1,1872 @@ +from typing import Set, Tuple, Optional, List +from enum import Enum +import math +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer + + +class AudioProcessor(nn.Module): + """Converts audio waveforms to log-mel spectrograms with optional resampling.""" + + def __init__( + self, + sample_rate: int = 16000, + mel_bins: int = 64, + mel_hop_length: int = 160, + n_fft: int = 1024, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_waveform( + self, + waveform: torch.Tensor, + source_rate: int, + target_rate: int, + ) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(device=waveform.device, dtype=waveform.dtype) + + def waveform_to_mel( + self, + waveform: torch.Tensor, + waveform_sample_rate: int, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate) + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioPatchifier(Patchifier): + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + shift: int = 0, + ): + """ + Patchifier tailored for spectrogram/audio latents. + Args: + patch_size: Number of mel bins combined into a single patch. This + controls the resolution along the frequency axis. + sample_rate: Original waveform sampling rate. Used to map latent + indices back to seconds so downstream consumers can align audio + and video cues. + hop_length: Window hop length used for the spectrogram. Determines + how many real-time samples separate two consecutive latent frames. + audio_latent_downsample_factor: Ratio between spectrogram frames and + latent frames; compensates for additional downsampling inside the + VAE encoder. + is_causal: When True, timing is shifted to account for causal + receptive fields so timestamps do not peek into the future. + shift: Integer offset applied to the latent indices. Enables + constructing overlapping windows from the same latent sequence. 
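+ Example (illustrative): with the defaults above (sample_rate=16000, hop_length=160,
+ audio_latent_downsample_factor=4), consecutive latent frames are spaced
+ 4 * 160 / 16000 = 0.04 s apart in the original waveform.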
+ """ + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + self._patch_size = (1, patch_size, patch_size) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: AudioLatentShape) -> int: + return tgt_shape.frames + + def _get_audio_latent_time_in_sec( + self, + start_latent: int, + end_latent: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Converts latent indices into real-time seconds while honoring causal + offsets and the configured hop length. + Args: + start_latent: Inclusive start index inside the latent sequence. This + sets the first timestamp returned. + end_latent: Exclusive end index. Determines how many timestamps get + generated. + dtype: Floating-point dtype used for the returned tensor, allowing + callers to control precision. + device: Target device for the timestamp tensor. When omitted the + computation occurs on CPU to avoid surprising GPU allocations. + """ + if device is None: + device = torch.device("cpu") + + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + + if self.is_causal: + # Frame offset for causal alignment. + # The "+1" ensures the timestamp corresponds to the first sample that is fully available. + causal_offset = 1 + audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) + + return audio_mel_frame * self.hop_length / self.sample_rate + + def _compute_audio_timings( + self, + batch_size: int, + num_steps: int, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. + This helper method underpins `get_patch_grid_bounds` for the audio patchifier. + Args: + batch_size: Number of sequences to broadcast the timings over. + num_steps: Number of latent frames (time steps) to convert into timestamps. + device: Device on which the resulting tensor should reside. + """ + resolved_device = device + if resolved_device is None: + resolved_device = torch.device("cpu") + + start_timings = self._get_audio_latent_time_in_sec( + self.shift, + num_steps + self.shift, + torch.float32, + resolved_device, + ) + start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + end_timings = self._get_audio_latent_time_in_sec( + self.shift + 1, + num_steps + self.shift + 1, + torch.float32, + resolved_device, + ) + end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + return torch.stack([start_timings, end_timings], dim=-1) + + def patchify( + self, + audio_latents: torch.Tensor, + ) -> torch.Tensor: + """ + Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` + to derive timestamps for each latent frame based on the configured hop + length and downsampling. + Args: + audio_latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the + corresponding timing metadata when needed. 
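+ Example (illustrative shapes for an 8-channel latent with 16 mel bins):
+     >>> import torch
+     >>> patchifier = AudioPatchifier(patch_size=1)
+     >>> latents = torch.zeros(2, 8, 10, 16)  # (B, C, T, F)
+     >>> patchifier.patchify(latents).shape
+     torch.Size([2, 10, 128])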
+ """ + audio_latents = einops.rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + return audio_latents + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. + Use `get_patch_grid_bounds` to recompute the timestamps that describe each + frame's position in real time. + Args: + audio_latents: Latent tensor to unpatchify. + output_shape: Shape of the unpatched output tensor. + Returns: + Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing + metadata associated with the restored latents. + """ + # audio_latents shape: (batch, time, freq * channels) + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=output_shape.channels, + f=output_shape.mel_bins, + ) + + return audio_latents + + def unpatchify_audio( + self, + audio_latents: torch.Tensor, + channels: int, + mel_bins: int + ) -> torch.Tensor: + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=channels, + f=mel_bins, + ) + return audio_latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the temporal bounds `[inclusive start, exclusive end)` for every + patch emitted by `patchify`. For audio this corresponds to timestamps in + seconds aligned with the original spectrogram grid. + The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores the `[start, end)` timestamps per patch + Args: + output_shape: Audio grid specification describing the number of time steps. + device: Target device for the returned tensor. 
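+ Example (illustrative, using the defaults sample_rate=16000, hop_length=160,
+ audio_latent_downsample_factor=4, is_causal=True): for one sequence with
+ three latent frames the result has shape [1, 1, 3, 2] and holds the bounds
+ [[0.00, 0.01], [0.01, 0.05], [0.05, 0.09]] in seconds; after the causally
+ clipped first frame, every latent frame spans 40 ms.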
+ """ + if not isinstance(output_shape, AudioLatentShape): + raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") + + return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class AttnBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn( + in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> torch.nn.Module: + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return torch.nn.Identity() + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +class CausalConv2d(torch.nn.Module): + """ + A causal 2D convolution. + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. 
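+ Example (illustrative): with kernel_size=3, dilation=1 and
+ causality_axis=CausalityAxis.HEIGHT, the padding applied before the
+ convolution is (left, right, top, bottom) = (1, 1, 2, 0), so each output row
+ depends only on the current and preceding input rows.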
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = torch.nn.modules.utils._pair(kernel_size) + dilation = torch.nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + padding: tuple[int, int, int, int] | None = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> torch.nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. 
+ Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding="same", + ), + ] + ) + + self.convs2 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): + super(ResBlock2, self).__init__() + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv in self.convs: + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv(xt) + x = xt + x + return x + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = torch.nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + 
self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class Downsample(torch.nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
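+ # Average pooling halves both spatial dimensions, e.g. (B, C, 64, 64) -> (B, C, 32, 32),
+ # matching the output size produced by the strided-convolution branch above.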
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +def build_downsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[torch.nn.ModuleList, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = torch.nn.ModuleList() + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + down_modules.append(down) + + return down_modules, block_in + + +class Upsample(torch.nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. 
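+ # Illustrative shape check for HEIGHT causality: (B, C, 3, W) becomes (B, C, 6, 2W)
+ # after interpolation and convolution, and dropping the first row below leaves
+ # (B, C, 5, 2W), i.e. length 1 + 2 * (3 - 1) along the causal axis.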
+ match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[torch.nn.ModuleList, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = torch.nn.ModuleList() + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = torch.nn.Module() + stage.block = torch.nn.ModuleList() + stage.attn = torch.nn.ModuleList() + block_out = ch * ch_mult[level] + + for _ in range(num_res_blocks + 1): + stage.block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + if level != 0: + stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict. 
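+ Note (added for clarity, assuming the default audio VAE configuration): the 128
+ statistics line up with the patchified latent features, i.e. z_channels (8) times
+ the downsampled mel bins (16) fused into one axis, and `normalize` /
+ `un_normalize` broadcast them over tensors of shape (B, T, 128).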
+ """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> torch.nn.Module: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = torch.nn.Module() + mid.block_1 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() + mid.block_2 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: + """Run features through the middle block.""" + features = mid.block_1(features, temb=None) + features = mid.attn_1(features) + return mid.block_2(features, temb=None) + + +class LTX2AudioEncoder(torch.nn.Module): + """ + Encoder that compresses audio spectrograms into latent representations. + The encoder uses a series of downsampling blocks with residual connections, + attention mechanisms, and configurable causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int = 2, + resolution: int = 256, + z_channels: int = 8, + double_z: bool = True, + attn_type: AttentionType = AttentionType.VANILLA, + mid_block_add_attention: bool = False, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + sample_rate: int = 16000, + mel_hop_length: int = 160, + n_fft: int = 1024, + is_causal: bool = True, + mel_bins: int = 64, + **_ignore_kwargs, + ) -> None: + """ + Initialize the Encoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + ch: Base number of feature channels used in the first convolution layer. + ch_mult: Multiplicative factors for the number of channels at each resolution level. + num_res_blocks: Number of residual blocks to use at each resolution level. + attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. + resolution: Input spatial resolution of the spectrogram (height, width). + z_channels: Number of channels in the latent representation. + norm_type: Normalization layer type to use within the network (e.g., group, batch). + causality_axis: Axis along which convolutions should be causal (e.g., time axis). + sample_rate: Audio sample rate in Hz for the input signals. + mel_hop_length: Hop length used when computing the mel spectrogram. 
+ n_fft: FFT size used to compute the spectrogram. + mel_bins: Number of mel-frequency bins in the input spectrogram. + in_channels: Number of channels in the input spectrogram tensor. + double_z: If True, predict both mean and log-variance (doubling latent channels). + is_causal: If True, use causal convolutions suitable for streaming setups. + dropout: Dropout probability used in residual and mid blocks. + attn_type: Type of attention mechanism to use in attention blocks. + resamp_with_conv: If True, perform resolution changes using strided convolutions. + mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. + """ + super().__init__() + + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + self.causality_axis = causality_axis + self.attn_type = attn_type + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + self.non_linearity = torch.nn.SiLU() + + self.down, block_in = build_downsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: + """ + Encode audio spectrogram into latent representations. 
+ Args: + spectrogram: Input spectrogram of shape (batch, channels, time, frequency) + Returns: + Encoded latent representation of shape (batch, channels, frames, mel_bins) + """ + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage.block[block_idx](h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != self.num_resolutions - 1: + h = stage.downsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + h = self.norm_out(h) + h = self.non_linearity(h) + return self.conv_out(h) + + def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: + """ + Normalize encoder latents using per-channel statistics. + When the encoder is configured with ``double_z=True``, the final + convolution produces twice the number of latent channels, typically + interpreted as two concatenated tensors along the channel dimension + (e.g., mean and variance or other auxiliary parameters). + This method intentionally uses only the first half of the channels + (the "mean" component) as input to the patchifier and normalization + logic. The remaining channels are left unchanged by this method and + are expected to be consumed elsewhere in the VAE pipeline. + If ``double_z=False``, the encoder output already contains only the + mean latents and the chunking operation simply returns that tensor. + """ + means = torch.chunk(latent_output, 2, dim=1)[0] + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[1], + frames=means.shape[2], + mel_bins=means.shape[3], + ) + latent_patched = self.patchifier.patchify(means) + latent_normalized = self.per_channel_statistics.normalize(latent_patched) + return self.patchifier.unpatchify(latent_normalized, latent_shape) + + +class LTX2AudioDecoder(torch.nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + out_ch: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + resolution: int=256, + z_channels: int=8, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + ) -> None: + """ + Initialize the Decoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions + - resolution, z_channels + - norm_type, causality_axis + """ + super().__init__() + + # Internal behavioural defaults that are not driven by the checkpoint. 
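+ # (The same values appear as constructor defaults on LTX2AudioEncoder.)
+ #
+ # Added shape note, assuming the default causal configuration: forward() takes
+ # latents of shape (B, 8, T, 16) and returns a spectrogram of shape
+ # (B, 2, 4 * T - 3, 64); see _denormalize_latents and _adjust_output_shape below.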
+ resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = torch.nn.SiLU() + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """ + Decode latent features back to audio spectrograms. 
+ Args: + sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Adjust output shape to match target dimensions for variable-length audio. + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: AudioLatentShape describing (batch, channels, time, mel bins) + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + h = block(h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != 0 and hasattr(stage, "upsample"): + h = stage.upsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return h + + h = 
self.norm_out(h) + h = self.non_linearity(h) + h = self.conv_out(h) + return torch.tanh(h) if self.tanh_out else h + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor) -> torch.Tensor: + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + return filter_.view(1, 1, kernel_size) + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff: float = 0.5, + half_width: float = 0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ) -> None: + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, n_channels, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels) + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + persistent: bool = True, + window_type: str = "kaiser", + ) -> None: + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter equivalent to torchaudio.functional.resample + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). 
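+ # Worked example (added for clarity): with the default ratio=2 and kernel_size=None
+ # this yields kernel_size=12, pad=5 and pad_left=pad_right=15; forward() then
+ # replicate-pads 5 samples per side, doubles the length with the strided transposed
+ # convolution and trims 15 samples from each end, so an input of length L comes out
+ # with length 2 * L.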
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter, persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, n_channels, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels) + return x[..., self.pad_left : -self.pad_right] + + +class DownSample1d(nn.Module): + def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None: + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation: nn.Module, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ) -> None: + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + return self.downsample(x) + + +class Snake(nn.Module): + def __init__( + self, + in_features: int, + alpha: float = 1.0, + alpha_trainable: bool = True, + alpha_logscale: bool = True, + ) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, + in_features: int, + alpha: float = 1.0, + alpha_trainable: bool = True, + alpha_logscale: bool = True, + ) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2) + + +class AMPBlock1(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: tuple[int, int, int] = (1, 3, 5), + activation: str = "snake", + ) -> None: + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + 
self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + ] + ) + + self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]) + self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +class LTX2Vocoder(torch.nn.Module): + """ + LTX2Vocoder model for synthesizing audio from Mel spectrograms. + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. + upsample_rates: List of upsampling rates. + This value is read from the checkpoint at `config.vocoder.upsample_rates`. + upsample_kernel_sizes: List of kernel sizes for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. + resblock_dilation_sizes: List of dilation sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. + upsample_initial_channel: Initial number of channels for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. + resblock: Type of residual block to use ("1", "2", or "AMP1"). + This value is read from the checkpoint at `config.vocoder.resblock`. + output_sampling_rate: Waveform sample rate. + This value is read from the checkpoint at `config.vocoder.output_sampling_rate`. + activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1". + use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True). + apply_final_activation: Whether to apply the final tanh/clamp activation. + use_bias_at_final: Whether to use bias in the final conv layer. + """ + + def __init__( # noqa: PLR0913 + self, + resblock_kernel_sizes: List[int] | None = [3, 7, 11], + upsample_rates: List[int] | None = [6, 5, 2, 2, 2], + upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4], + resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel: int = 1024, + resblock: str = "1", + output_sampling_rate: int = 24000, + activation: str = "snake", + use_tanh_at_final: bool = True, + apply_final_activation: bool = True, + use_bias_at_final: bool = True, + ) -> None: + super().__init__() + + # Mutable default values are not supported as default arguments. 
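+ # (Callers may also pass None explicitly; the checks below then restore the same
+ # values as the signature defaults.)
+ #
+ # Added sizing note for these defaults: upsample_initial_channel=1024 is halved after
+ # each of the 5 upsample stages, so the final convolution sees 1024 // 2**5 = 32
+ # channels, and every stage is followed by len(resblock_kernel_sizes)=3 residual
+ # blocks whose outputs are averaged in forward().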
+ if resblock_kernel_sizes is None: + resblock_kernel_sizes = [3, 7, 11] + if upsample_rates is None: + upsample_rates = [6, 5, 2, 2, 2] + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [16, 15, 8, 4, 4] + if resblock_dilation_sizes is None: + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + self.output_sampling_rate = output_sampling_rate + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.use_tanh_at_final = use_tanh_at_final + self.apply_final_activation = apply_final_activation + self.is_amp = resblock == "AMP1" + + # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel + # bins each), 2 output channels. + self.conv_pre = nn.Conv1d( + in_channels=128, + out_channels=upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1 + + self.ups = nn.ModuleList( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)) + ) + + final_channels = upsample_initial_channel // (2 ** len(upsample_rates)) + self.resblocks = nn.ModuleList() + + for i in range(len(upsample_rates)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): + if self.is_amp: + self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, kernel_size, dilations)) + + if self.is_amp: + self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels)) + else: + self.act_post = nn.LeakyReLU() + + # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo). + self.conv_post = nn.Conv1d( + in_channels=final_channels, + out_channels=2, + kernel_size=7, + stride=1, + padding=3, + bias=use_bias_at_final, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the vocoder. + Args: + x: Input Mel spectrogram tensor. Can be either: + - 3D: (batch_size, time, mel_bins) for mono + - 4D: (batch_size, 2, time, mel_bins) for stereo + Returns: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time) + + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = einops.rearrange(x, "b s c t -> b (s c) t") + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + if not self.is_amp: + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + start = i * self.num_kernels + end = start + self.num_kernels + + # Evaluate all resblocks with the same input tensor so they can run + # independently (and thus in parallel on accelerator hardware) before + # aggregating their outputs via mean. + block_outputs = torch.stack( + [self.resblocks[idx](x) for idx in range(start, end)], + dim=0, + ) + x = block_outputs.mean(dim=0) + + x = self.act_post(x) + x = self.conv_post(x) + + if self.apply_final_activation: + x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1) + + return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT x Hann-window bases. 
+ The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None: + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + Args: + y: Waveform tensor of shape (B, T). + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + ) -> None: + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + # Initialized to zeros; load_state_dict overwrites with the checkpoint's + # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]). + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + Args: + y: Waveform tensor of shape (B, T). + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class LTX2VocoderWithBWE(nn.Module): + """LTX2Vocoder with bandwidth extension (BWE) upsampling. + Chains a mel-to-wav vocoder with a BWE module that upsamples the output + to a higher sample rate. 
The BWE computes a mel spectrogram from the + vocoder output, runs it through a second generator to predict a residual, + and adds it to a sinc-resampled skip connection. + """ + + def __init__( + self, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + hop_length: int = 80, + ) -> None: + super().__init__() + self.vocoder = LTX2Vocoder( + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[5, 2, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 4, 4, 4, 4, 4], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=1536, + resblock="AMP1", + activation="snakebeta", + use_tanh_at_final=False, + apply_final_activation=True, + use_bias_at_final=False, + output_sampling_rate=input_sampling_rate, + ) + self.bwe_generator = LTX2Vocoder( + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[6, 5, 2, 2, 2], + upsample_kernel_sizes=[12, 11, 4, 4, 4], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=512, + resblock="AMP1", + activation="snakebeta", + use_tanh_at_final=False, + apply_final_activation=False, + use_bias_at_final=False, + output_sampling_rate=output_sampling_rate, + ) + + self.mel_stft = MelSTFT( + filter_length=512, + hop_length=hop_length, + win_length=512, + n_mel_channels=64, + ) + self.input_sampling_rate = input_sampling_rate + self.output_sampling_rate = output_sampling_rate + self.hop_length = hop_length + # Compute the resampler on CPU so the sinc filter is materialized even when + # the model is constructed on meta device (SingleGPUModelBuilder pattern). + # The filter is not stored in the checkpoint (persistent=False). + with torch.device("cpu"): + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann" + ) + + @property + def conv_pre(self) -> nn.Conv1d: + return self.vocoder.conv_pre + + @property + def conv_post(self) -> nn.Conv1d: + return self.vocoder.conv_post + + def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor: + """Compute log-mel spectrogram from waveform using causal STFT bases. + Args: + audio: Waveform tensor of shape (B, C, T). + Returns: + mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames). + """ + batch, n_channels, _ = audio.shape + flat = audio.reshape(batch * n_channels, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + """Run the full vocoder + BWE forward pass. + Args: + mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo + or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward. + Returns: + Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1]. 
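+ Example (illustrative, with the default 16 kHz -> 48 kHz configuration): a stereo
+ mel input of shape (1, 2, 100, 64) is first vocoded into a 16 kHz waveform of
+ 100 * 160 = 16000 samples, then bandwidth-extended and returned with shape
+ (1, 2, 48000), i.e. one second of 48 kHz stereo audio.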
+ """ + x = self.vocoder(mel_spec) + _, _, length_low_rate = x.shape + output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate + + # Pad to multiple of hop_length for exact mel frame count + remainder = length_low_rate % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames) + mel = self._compute_mel(x) + + # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator + mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins) + residual = self.bwe_generator(mel_for_bwe) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :output_length] diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a658ec6ecb76bbf7372a0e0893ec52b2a8307ddd --- /dev/null +++ b/diffsynth/models/ltx2_common.py @@ -0,0 +1,388 @@ +from dataclasses import dataclass +from typing import NamedTuple, Protocol, Tuple +import torch +from torch import nn +from enum import Enum + + +class VideoPixelShape(NamedTuple): + """ + Shape of the tensor representing the video pixel array. Assumes BGR channel format. + """ + + batch: int + frames: int + height: int + width: int + fps: float + + +class SpatioTemporalScaleFactors(NamedTuple): + """ + Describes the spatiotemporal downscaling between decoded video space and + the corresponding VAE latent grid. + """ + + time: int + width: int + height: int + + @classmethod + def default(cls) -> "SpatioTemporalScaleFactors": + return cls(time=8, width=32, height=32) + + +VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default() + + +class VideoLatentShape(NamedTuple): + """ + Shape of the tensor representing video in VAE latent space. + The latent representation is a 5D tensor with dimensions ordered as + (batch, channels, frames, height, width). Spatial and temporal dimensions + are downscaled relative to pixel space according to the VAE's scale factors. + """ + + batch: int + channels: int + frames: int + height: int + width: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.height, self.width]) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "VideoLatentShape": + return VideoLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + height=shape[3], + width=shape[4], + ) + + def mask_shape(self) -> "VideoLatentShape": + return self._replace(channels=1) + + @staticmethod + def from_pixel_shape( + shape: VideoPixelShape, + latent_channels: int = 128, + scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS, + ) -> "VideoLatentShape": + frames = (shape.frames - 1) // scale_factors[0] + 1 + height = shape.height // scale_factors[1] + width = shape.width // scale_factors[2] + + return VideoLatentShape( + batch=shape.batch, + channels=latent_channels, + frames=frames, + height=height, + width=width, + ) + + def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape": + return self._replace( + channels=3, + frames=(self.frames - 1) * scale_factors.time + 1, + height=self.height * scale_factors.height, + width=self.width * scale_factors.width, + ) + + +class AudioLatentShape(NamedTuple): + """ + Shape of audio in VAE latent space: (batch, channels, frames, mel_bins). 
+ mel_bins is the number of frequency bins from the mel-spectrogram encoding. + """ + + batch: int + channels: int + frames: int + mel_bins: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.mel_bins]) + + def mask_shape(self) -> "AudioLatentShape": + return self._replace(channels=1, mel_bins=1) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "AudioLatentShape": + return AudioLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + mel_bins=shape[3], + ) + + @staticmethod + def from_duration( + batch: int, + duration: float, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor) + + return AudioLatentShape( + batch=batch, + channels=channels, + frames=round(duration * latents_per_second), + mel_bins=mel_bins, + ) + + @staticmethod + def from_video_pixel_shape( + shape: VideoPixelShape, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + return AudioLatentShape.from_duration( + batch=shape.batch, + duration=float(shape.frames) / float(shape.fps), + channels=channels, + mel_bins=mel_bins, + sample_rate=sample_rate, + hop_length=hop_length, + audio_latent_downsample_factor=audio_latent_downsample_factor, + ) + + +@dataclass(frozen=True) +class LatentState: + """ + State of latents during the diffusion denoising process. + Attributes: + latent: The current noisy latent tensor being denoised. + denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising). + positions: Positional indices for each latent element, used for positional embeddings. + clean_latent: Initial state of the latent before denoising, may include conditioning latents. + """ + + latent: torch.Tensor + denoise_mask: torch.Tensor + positions: torch.Tensor + clean_latent: torch.Tensor + + def clone(self) -> "LatentState": + return LatentState( + latent=self.latent.clone(), + denoise_mask=self.denoise_mask.clone(), + positions=self.positions.clone(), + clean_latent=self.clean_latent.clone(), + ) + + +class NormType(Enum): + """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).""" + + GROUP = "group" + PIXEL = "pixel" + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). 
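The audio latent frame-rate arithmetic in AudioLatentShape.from_duration above works out to 25 latent frames per second with the default sample rate, hop length, and downsample factor; a small check:

sample_rate, hop_length, downsample = 16000, 160, 4
latents_per_second = sample_rate / hop_length / downsample   # 25.0
duration = 4.0                                                # seconds of audio
frames = round(duration * latents_per_second)                 # 100 latent frames
assert latents_per_second == 25.0 and frames == 100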
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + return PixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: + """Root-mean-square (RMS) normalize `x` over its last dimension. + Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized + shape and forwards `weight` and `eps`. + """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) + + +@dataclass(frozen=True) +class Modality: + """ + Input data for a single modality (video or audio) in the transformer. + Bundles the latent tokens, timestep embeddings, positional information, + and text conditioning context for processing by the diffusion transformer. + Attributes: + latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is + the batch size, *T* is the total number of tokens (noisy + + conditioning), and *D* is the input dimension. + timesteps: Per-token timestep embeddings, shape ``(B, T)``. + positions: Positional coordinates, shape ``(B, 3, T)`` for video + (time, height, width) or ``(B, 1, T)`` for audio. + context: Text conditioning embeddings from the prompt encoder. + enabled: Whether this modality is active in the current forward pass. + context_mask: Optional mask for the text context tokens. + attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``. + Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no + attention. ``None`` means unrestricted (full) attention between + all tokens. Built incrementally by conditioning items; see + :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`. + """ + + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation. + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None + attention_mask: torch.Tensor | None = None + + +def to_denoised( + sample: torch.Tensor, + velocity: torch.Tensor, + sigma: float | torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoising velocity to denoised sample. + Returns: + Denoised sample + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype) + return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) + + + +class Patchifier(Protocol): + """ + Protocol for patchifiers that convert latent tensors into patches and assemble them back. + """ + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + ... 
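The to_denoised helper above converts a velocity prediction back into a clean sample. Under the common rectified-flow parameterization (an assumption made here for illustration, not stated in this file), the noisy sample is (1 - sigma) * x0 + sigma * noise and the velocity target is noise - x0, so subtracting sigma * velocity recovers x0 exactly:

import torch

x0 = torch.randn(2, 8)
noise = torch.randn(2, 8)
sigma = 0.7

sample = (1 - sigma) * x0 + sigma * noise      # noisy sample at noise level sigma
velocity = noise - x0                          # rectified-flow velocity target
denoised = sample - velocity * sigma           # what to_denoised computes

assert torch.allclose(denoised, x0, atol=1e-6)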
+ """ + Convert latent tensors into flattened patch tokens. + Args: + latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. + """ + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: AudioLatentShape | VideoLatentShape, + ) -> torch.Tensor: + """ + Converts latent tensors between spatio-temporal formats and flattened sequence representations. + Args: + latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. + output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or + VideoLatentShape. + Returns: + Dense latent tensor restored from the flattened representation. + """ + + @property + def patch_size(self) -> Tuple[int, int, int]: + ... + """ + Returns the patch size as a tuple of (temporal, height, width) dimensions + """ + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: torch.device | None = None, + ) -> torch.Tensor: + ... + """ + Compute metadata describing where each latent patch resides within the + grid specified by `output_shape`. + Args: + output_shape: Target grid layout for the patches. + device: Target device for the returned tensor. + Returns: + Tensor containing patch coordinate metadata such as spatial or temporal intervals. + """ + + +def get_pixel_coords( + latent_coords: torch.Tensor, + scale_factors: SpatioTemporalScaleFactors, + causal_fix: bool = False, +) -> torch.Tensor: + """ + Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling + each axis (frame/time, height, width) with the corresponding VAE downsampling factors. + Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. + Args: + latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. + scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied + per axis. + causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs + that treat frame zero differently still yield non-negative timestamps. + """ + # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout. + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width) + scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) + + # Apply per-axis scaling to convert latent bounds into pixel-space coordinates. + pixel_coords = latent_coords * scale_tensor + + if causal_fix: + # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`. + # Shift and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] 
+ 1 - scale_factors[0]).clamp(min=0) + + return pixel_coords diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..9df0ed3a78b5ee074386c67d2f2bc931fd6ed19b --- /dev/null +++ b/diffsynth/models/ltx2_dit.py @@ -0,0 +1,1683 @@ +import math +import functools +from dataclasses import dataclass, replace +from enum import Enum +from typing import Optional, Tuple, Callable +import numpy as np +import torch +from einops import rearrange +from .ltx2_common import rms_norm, Modality +from ..core.attention.attention import attention_forward +from ..core import gradient_checkpoint_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: 
bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb + + +class PerturbationType(Enum): + """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" + + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + """A single perturbation specifying which attention type to skip and in which blocks.""" + + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + """Configuration holding a list of perturbations for a single sample.""" + + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + """Perturbation configurations for a batch, with utilities for generating attention masks.""" + + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def 
all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) + + + +ADALN_NUM_BASE_PARAMS = 6 +# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm. +ADALN_NUM_CROSS_ATTN_PARAMS = 3 + + +def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int: + """Total number of AdaLN parameters per block.""" + return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0) + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... 
(d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + 
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + apply_gated_attention: bool = False, + ) -> None: + super().__init__() + self.rope_type = rope_type + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) + else: + self.to_gate_logits = None + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool = False, + ) -> torch.Tensor: + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) + + # Reshape for attention_forward using unflatten + q = q.unflatten(-1, (self.heads, self.dim_head)) + k = k.unflatten(-1, (self.heads, self.dim_head)) + v = v.unflatten(-1, (self.heads, self.dim_head)) + + out = attention_forward( + q=q, + k=k, + v=v, + q_pattern="b s n d", + k_pattern="b s n d", + v_pattern="b s n d", + out_pattern="b s n d", + attn_mask=mask + ) + + # Reshape back to original format + out = out.flatten(2, 3) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + # Reshape to (B, T, H, D) for per-head gating + out = out.view(b, t, self.heads, self.dim_head) + # Apply gating: 2 * sigmoid(x) 
so that zero-init gives identity (2 * 0.5 = 1.0) + gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H) + out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1) + # Reshape back to (B, T, H*D) + out = out.view(b, t, self.heads * self.dim_head) + + return self.to_out(out) + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + prompt_timestep: torch.Tensor | None = None + self_attention_mask: torch.Tensor | None = ( + None # Additive log-space self-attention bias (B, 1, T, T), None = full attention + ) + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + caption_projection: torch.nn.Module | None = None, + prompt_adaln: AdaLayerNormSingle | None = None, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + self.caption_projection = caption_projection + self.prompt_adaln = prompt_adaln + + def _prepare_timestep( + self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + timestep_scaled = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = adaln( + timestep_scaled.flatten(), + hidden_dtype=hidden_dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + return timestep, embedded_timestep + + def 
_prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + batch_size = x.shape[0] + return context.view(batch_size, -1, x.shape[-1]) + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_self_attention_mask( + self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype + ) -> torch.Tensor | None: + """Prepare self-attention mask by converting [0,1] values to additive log-space bias. + Input shape: (B, T, T) with values in [0, 1]. + Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value + for masked positions. + Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum + representable value). Strictly positive entries are converted via log-space for + smooth attenuation, with small values clamped for numerical stability. + Returns None if input is None (no masking). + """ + if attention_mask is None: + return None + + # Convert [0, 1] attention mask to additive log-space bias: + # 1.0 -> log(1.0) = 0.0 (no bias, full attention) + # 0.0 -> finfo.min (fully masked) + finfo = torch.finfo(x_dtype) + eps = finfo.tiny + + bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype) + positive = attention_mask > 0 + if positive.any(): + bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype) + + return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + cross_modality: Modality | None = None, # noqa: ARG002 + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + batch_size = x.shape[0] + timestep, embedded_timestep = self._prepare_timestep( + modality.timesteps, self.adaln, batch_size, modality.latent.dtype + ) + prompt_timestep = None + if self.prompt_adaln is not None: + prompt_timestep, _ = self._prepare_timestep( + modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype + ) + context = self._prepare_context(modality.context, x) + attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + self_attention_mask = 
self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + prompt_timestep=prompt_timestep, + self_attention_mask=self_attention_mask, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + caption_projection: torch.nn.Module | None = None, + prompt_adaln: AdaLayerNormSingle | None = None, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + caption_projection=caption_projection, + prompt_adaln=prompt_adaln, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + cross_modality: Modality | None = None, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + if cross_modality is None: + return transformer_args + + if cross_modality.sigma.numel() > 1: + if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]: + raise ValueError("Cross modality sigma must have the same batch size as the modality") + if cross_modality.sigma.ndim != 1: + raise ValueError("Cross modality sigma must be a 1D tensor") + + cross_timestep = cross_modality.sigma.view( + modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:]) + ) + + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=cross_timestep, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def _prepare_cross_attention_timestep( + self, + timestep: torch.Tensor | None, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> 
tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + apply_gated_attention: bool = False + cross_attention_adaln: bool = False + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=video.apply_gated_attention, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=video.apply_gated_attention, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln) + self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=audio.apply_gated_attention, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=audio.apply_gated_attention, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=video.apply_gated_attention, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + apply_gated_attention=audio.apply_gated_attention, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or ( + audio is not None and audio.cross_attention_adaln + ) + + if self.cross_attention_adaln and video is not None: + self.prompt_scale_shift_table = 
torch.nn.Parameter(torch.empty(2, video.dim)) + if self.cross_attention_adaln and audio is not None: + self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + scale_shift_indices: slice, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) + ) + + scale, shift = (t.squeeze(2) for t in scale_shift_ada_values) + (gate,) = (t.squeeze(2) for t in gate_ada_values) + + return scale, shift, gate + + def _apply_text_cross_attention( + self, + x: torch.Tensor, + context: torch.Tensor, + attn: Attention, + scale_shift_table: torch.Tensor, + prompt_scale_shift_table: torch.Tensor | None, + timestep: torch.Tensor, + prompt_timestep: torch.Tensor | None, + context_mask: torch.Tensor | None, + cross_attention_adaln: bool = False, + ) -> torch.Tensor: + """Apply text cross-attention, with optional AdaLN modulation.""" + if cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9)) + return apply_cross_attention_adaln( + x, + context, + attn, + shift_q, + scale_q, + gate, + prompt_scale_shift_table, + prompt_timestep, + context_mask, + self.norm_eps, + ) + return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask) + + def forward( # noqa: PLR0915 + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + if video is None and audio is None: + raise ValueError("At least one of video or audio must be provided") + + batch_size = (video or audio).x.shape[0] + + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and ax.numel() > 0) + run_v2a = run_ax and (video is not None and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) + ) + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + del vshift_msa, vscale_msa + + all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) + none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) + v_mask = ( + perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, 
self.idx, vx) + if not all_perturbed and not none_perturbed + else None + ) + vx = ( + vx + + self.attn1( + norm_vx, + pe=video.positional_embeddings, + mask=video.self_attention_mask, + perturbation_mask=v_mask, + all_perturbed=all_perturbed, + ) + * vgate_msa + ) + del vgate_msa, norm_vx, v_mask + vx = vx + self._apply_text_cross_attention( + vx, + video.context, + self.attn2, + self.scale_shift_table, + getattr(self, "prompt_scale_shift_table", None), + video.timesteps, + video.prompt_timestep, + video.context_mask, + cross_attention_adaln=self.cross_attention_adaln, + ) + + if run_ax: + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) + ) + + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + del ashift_msa, ascale_msa + all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) + none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) + a_mask = ( + perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + if not all_perturbed and not none_perturbed + else None + ) + ax = ( + ax + + self.audio_attn1( + norm_ax, + pe=audio.positional_embeddings, + mask=audio.self_attention_mask, + perturbation_mask=a_mask, + all_perturbed=all_perturbed, + ) + * agate_msa + ) + del agate_msa, norm_ax, a_mask + ax = ax + self._apply_text_cross_attention( + ax, + audio.context, + self.audio_attn2, + self.audio_scale_shift_table, + getattr(self, "audio_prompt_scale_shift_table", None), + audio.timesteps, + audio.prompt_timestep, + audio.context_mask, + cross_attention_adaln=self.cross_attention_adaln, + ) + + # Audio - Video cross attention. + if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx): + scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + slice(0, 2), + ) + vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v + del scale_ca_video_a2v, shift_ca_video_a2v + + scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + slice(0, 2), + ) + ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v + del scale_ca_audio_a2v, shift_ca_audio_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled + + if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx): + scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + slice(2, 4), + ) + ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a + del scale_ca_audio_v2a, shift_ca_audio_v2a + scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + 
video.cross_scale_shift_timestep, + video.cross_gate_timestep, + slice(2, 4), + ) + vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a + del scale_ca_video_v2a, shift_ca_video_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled + + del vx_norm3, ax_norm3 + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) + ) + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) + ) + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None + + +def apply_cross_attention_adaln( + x: torch.Tensor, + context: torch.Tensor, + attn: Attention, + q_shift: torch.Tensor, + q_scale: torch.Tensor, + q_gate: torch.Tensor, + prompt_scale_shift_table: torch.Tensor, + prompt_timestep: torch.Tensor, + context_mask: torch.Tensor | None = None, + norm_eps: float = 1e-6, +) -> torch.Tensor: + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class LTXModelType(Enum): + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + This class implements the transformer blocks for the LTX model. 
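The recurring AdaLN pattern in BasicAVTransformerBlock above (via get_ada_values) folds a learned per-block scale_shift_table into the per-token timestep projection, then applies shift/scale to the RMS-normalized input and a gate to the residual branch. A stripped-down sketch with hypothetical sizes; sub_layer stands in for self-attention or the feed-forward, and F.rms_norm requires a recent PyTorch (the file's own rms_norm wraps the same function):

import torch
import torch.nn.functional as F

batch, tokens, dim, num_params = 2, 16, 64, 6    # 6 AdaLN params: shift/scale/gate for attn and for the FFN
x = torch.randn(batch, tokens, dim)

scale_shift_table = torch.zeros(num_params, dim)             # learned, one table per block
timestep = torch.randn(batch, 1, num_params * dim)           # AdaLayerNormSingle output, broadcast over tokens

ada = scale_shift_table[None, None] + timestep.reshape(batch, 1, num_params, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

def sub_layer(h):                                 # stand-in for attention or the feed-forward
    return h

normed = F.rms_norm(x, (dim,)) * (1 + scale_msa) + shift_msa  # modulate the normalized input
x = x + sub_layer(normed) * gate_msa                          # gate the residual branch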
+ """ + _repeated_blocks = ["BasicAVTransformerBlock"] + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + caption_channels: int = 3840, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [20, 2048, 2048], + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = [20], + av_ca_timestep_scale_multiplier: int = 1000, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + apply_gated_attention: bool = False, + cross_attention_adaln: bool = False, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + self.cross_attention_adaln = cross_attention_adaln + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + apply_gated_attention=apply_gated_attention, + ) + + @property + def _adaln_embedding_coefficient(self) -> int: + return adaln_embedding_coefficient(self.cross_attention_adaln) + + def _init_video( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize 
video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) + self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None + + # Video caption projection + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.inner_dim, + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + + self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) + self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None + + # Audio caption projection + if caption_channels is not None: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.audio_inner_dim, + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + 
caption_projection=getattr(self, "caption_projection", None), + prompt_adaln=getattr(self, "prompt_adaln_single", None), + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + caption_projection=getattr(self, "audio_caption_projection", None), + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + caption_projection=getattr(self, "caption_projection", None), + prompt_adaln=getattr(self, "prompt_adaln_single", None), + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + caption_projection=getattr(self, "audio_caption_projection", None), + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + apply_gated_attention: bool, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + apply_gated_attention=apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, + ) + if self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + 
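+                    # A None modality config disables the corresponding branch inside the block.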
video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. + Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + video, audio = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def _forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. 
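+        Args:
+            video: Video modality inputs (latents, sigma, timesteps, positions, context), or None for audio-only models.
+            audio: Audio modality inputs in the same layout, or None for video-only models.
+            perturbations: Batched perturbation configuration forwarded to every transformer block.
+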
+ Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax + + def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps, sigma, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False): + cross_pe_max_pos = None + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self._init_preprocessors(cross_pe_max_pos) + video = Modality(video_latents, sigma, video_timesteps, video_positions, video_context) + audio = Modality(audio_latents, sigma, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None + vx, ax = self._forward(video=video, audio=audio, perturbations=None, use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload) + return vx, ax diff --git a/diffsynth/models/ltx2_text_encoder.py b/diffsynth/models/ltx2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f3b1a3e1995b1adefadf31348a9fe19b9f8f02 --- /dev/null +++ b/diffsynth/models/ltx2_text_encoder.py @@ -0,0 +1,549 @@ +import math +import torch +import torch.nn as nn +from einops import rearrange +from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer +from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention, + FeedForward) +from .ltx2_common import rms_norm + + +class LTX2TextEncoder(Gemma3ForConditionalGeneration): + def __init__(self): + config = Gemma3Config( + **{ + "architectures": ["Gemma3ForConditionalGeneration"], + "boi_token_index": 255999, + "dtype": "bfloat16", + "eoi_token_index": 256000, + "eos_token_id": [1, 106], + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "text_config": { + "_sliding_window_pattern": 6, + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "cache_implementation": "hybrid", + "dtype": "bfloat16", + "final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 3840, + "initializer_range": 0.02, + "intermediate_size": 15360, + "layer_types": [ + "sliding_attention", 
"sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 16, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": { + "factor": 8.0, + "rope_type": "linear" + }, + "rope_theta": 1000000, + "sliding_window": 1024, + "sliding_window_pattern": 6, + "use_bidirectional_attention": False, + "use_cache": True, + "vocab_size": 262208 + }, + "transformers_version": "4.57.3", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "vision_use_head": False + } + }) + super().__init__(config) + + +class LTXVGemmaTokenizer: + """ + Tokenizer wrapper for Gemma models compatible with LTXV processes. + This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders, + ensuring correct settings and output formatting for downstream consumption. + """ + + def __init__(self, tokenizer_path: str, max_length: int = 1024): + """ + Initialize the tokenizer. + Args: + tokenizer_path (str): Path to the pretrained tokenizer files or model directory. + max_length (int, optional): Max sequence length for encoding. Defaults to 256. + """ + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, local_files_only=True, model_max_length=max_length + ) + # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much. + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.max_length = max_length + + def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]: + """ + Tokenize the given text and return token IDs and attention weights. + Args: + text (str): The input string to tokenize. + return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples. + If False (default), omits the indices. + Returns: + dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]: + A dictionary with a "gemma" key mapping to: + - a list of (token_id, attention_mask) tuples if return_word_ids is False; + - a list of (token_id, attention_mask, index) tuples if return_word_ids is True. 
+ Example: + >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8) + >>> tokenizer.tokenize_with_weights("hello world") + {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]} + """ + text = text.strip() + encoded = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="pt", + ) + input_ids = encoded.input_ids + attention_mask = encoded.attention_mask + tuples = [ + (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True)) + ] + out = {"gemma": tuples} + + if not return_word_ids: + # Return only (token_id, attention_mask) pairs, omitting token position + out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()} + + return out + + +class GemmaFeaturesExtractorProjLinear(nn.Module): + """ + Feature extractor module for Gemma models. + This module applies a single linear projection to the input tensor. + It expects a flattened feature tensor of shape (batch_size, 3840*49). + The linear layer maps this to a (batch_size, 3840) embedding. + Attributes: + aggregate_embed (nn.Linear): Linear projection layer. + """ + + def __init__(self) -> None: + """ + Initialize the GemmaFeaturesExtractorProjLinear module. + The input dimension is expected to be 3840 * 49, and the output is 3840. + """ + super().__init__() + self.aggregate_embed = nn.Linear(3840 * 49, 3840, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "left", + ) -> tuple[torch.Tensor, torch.Tensor | None]: + encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states + dtype = encoded.dtype + sequence_lengths = attention_mask.sum(dim=-1) + normed = _norm_and_concat_padded_batch(encoded, sequence_lengths, padding_side) + features = self.aggregate_embed(normed.to(dtype)) + return features, features + + +class GemmaSeperatedFeaturesExtractorProjLinear(nn.Module): + """22B: per-token RMS norm → rescale → dual aggregate embeds""" + + def __init__( + self, + num_layers: int, + embedding_dim: int, + video_inner_dim: int, + audio_inner_dim: int, + ): + super().__init__() + in_dim = embedding_dim * num_layers + self.video_aggregate_embed = torch.nn.Linear(in_dim, video_inner_dim, bias=True) + self.audio_aggregate_embed = torch.nn.Linear(in_dim, audio_inner_dim, bias=True) + self.embedding_dim = embedding_dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "left", # noqa: ARG002 + ) -> tuple[torch.Tensor, torch.Tensor | None]: + encoded = torch.stack(hidden_states, dim=-1) if isinstance(hidden_states, (list, tuple)) else hidden_states + normed = norm_and_concat_per_token_rms(encoded, attention_mask) + normed = normed.to(encoded.dtype) + v_dim = self.video_aggregate_embed.out_features + video = self.video_aggregate_embed(_rescale_norm(normed, v_dim, self.embedding_dim)) + audio = None + if self.audio_aggregate_embed is not None: + a_dim = self.audio_aggregate_embed.out_features + audio = self.audio_aggregate_embed(_rescale_norm(normed, a_dim, self.embedding_dim)) + return video, audio + + + +class _BasicTransformerBlock1D(nn.Module): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + apply_gated_attention: bool = False, + ): + super().__init__() + + self.attn1 = Attention( + query_dim=dim, + heads=heads, + dim_head=dim_head, + rope_type=rope_type, + 
apply_gated_attention=apply_gated_attention, + ) + + self.ff = FeedForward( + dim, + dim_out=dim, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + """ + Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or + other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can + substitute padded positions with learnable registers. The module is highly configurable for head size, number of + layers, and register usage. + Args: + attention_head_dim (int): Dimension of each attention head (default=128). + num_attention_heads (int): Number of attention heads (default=30). + num_layers (int): Number of transformer layers (default=2). + positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0). + positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]). + causal_temporal_positioning (bool): If True, uses causal attention (default=False). + num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables + register replacement. (default=128) + rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE). + double_precision_rope (bool): Use double precision rope calculation (default=False). 
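+        Example (illustrative sketch, not a verified doctest; assumes the default sizes, so the feature
+            dimension is 30 * 128 = 3840 and the sequence length is a multiple of the 128 learnable registers):
+            >>> connector = Embeddings1DConnector()
+            >>> feats = torch.randn(1, 1024, 3840)
+            >>> additive_mask = torch.zeros(1, 1, 1, 1024)  # 0.0 = keep, large negative = masked out
+            >>> out, out_mask = connector(feats, additive_mask)
+            >>> out.shape
+            torch.Size([1, 1024, 3840])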
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + attention_head_dim: int = 128, + num_attention_heads: int = 30, + num_layers: int = 2, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [4096], + causal_temporal_positioning: bool = False, + num_learnable_registers: int | None = 128, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + apply_gated_attention: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = ( + positional_embedding_max_pos if positional_embedding_max_pos is not None else [1] + ) + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + _BasicTransformerBlock1D( + dim=self.inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + ) + for _ in range(num_layers) + ] + ) + + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0 + ) + + def _replace_padded_with_learnable_registers( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[1] % self.num_learnable_registers == 0, ( + f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers " + f"{self.num_learnable_registers}." + ) + + num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers + learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1)) + attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int() + + non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :] + non_zero_nums = non_zero_hidden_states.shape[1] + pad_length = hidden_states.shape[1] - non_zero_nums + adjusted_hidden_states = nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0) + flipped_mask = torch.flip(attention_mask_binary, dims=[1]) + hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers + + attention_mask = torch.full_like( + attention_mask, + 0.0, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + return hidden_states, attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of Embeddings1DConnector. + Args: + hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]). + attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states). + Returns: + tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask. 
+ """ + if self.num_learnable_registers: + hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) + + indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device) + indices_grid = indices_grid[None, None, :] + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=self.inner_dim, + out_dtype=hidden_states.dtype, + theta=self.positional_embedding_theta, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + + for block in self.transformer_1d_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis) + + hidden_states = rms_norm(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextEncoderPostModules(nn.Module): + def __init__( + self, + separated_audio_video: bool = False, + embedding_dim_gemma: int = 3840, + num_layers_gemma: int = 49, + video_attention_heads: int = 32, + video_attention_head_dim: int = 128, + audio_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + num_connector_layers: int = 2, + apply_gated_attention: bool = False, + ): + super().__init__() + if not separated_audio_video: + self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear() + self.embeddings_connector = Embeddings1DConnector() + self.audio_embeddings_connector = Embeddings1DConnector() + else: + # LTX-2.3 + self.feature_extractor_linear = GemmaSeperatedFeaturesExtractorProjLinear( + num_layers_gemma, embedding_dim_gemma, video_attention_heads * video_attention_head_dim, + audio_attention_heads * audio_attention_head_dim) + self.embeddings_connector = Embeddings1DConnector( + attention_head_dim=video_attention_head_dim, + num_attention_heads=video_attention_heads, + num_layers=num_connector_layers, + apply_gated_attention=apply_gated_attention, + ) + self.audio_embeddings_connector = Embeddings1DConnector( + attention_head_dim=audio_attention_head_dim, + num_attention_heads=audio_attention_heads, + num_layers=num_connector_layers, + apply_gated_attention=apply_gated_attention, + ) + + def create_embeddings( + self, + video_features: torch.Tensor, + audio_features: torch.Tensor | None, + additive_attention_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + video_encoded, video_mask = self.embeddings_connector(video_features, additive_attention_mask) + video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask) + audio_encoded, _ = self.audio_embeddings_connector(audio_features, additive_attention_mask) + + return video_encoded, audio_encoded, binary_mask + + def process_hidden_states( + self, + hidden_states: tuple[torch.Tensor, ...], + attention_mask: torch.Tensor, + padding_side: str = "left", + ): + video_feats, audio_feats = self.feature_extractor_linear(hidden_states, attention_mask, padding_side) + additive_mask = _convert_to_additive_mask(attention_mask, video_feats.dtype) + video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask) + return video_enc, audio_enc, binary_mask + + +def _norm_and_concat_padded_batch( + encoded_text: torch.Tensor, + sequence_lengths: torch.Tensor, + padding_side: str = "right", +) -> torch.Tensor: + """Normalize and flatten multi-layer hidden states, respecting padding. 
+ Performs per-batch, per-layer normalization using masked mean and range, + then concatenates across the layer dimension. + Args: + encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. + sequence_lengths: Number of valid (non-padded) tokens per batch item. + padding_side: Whether padding is on "left" or "right". + Returns: + Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], + with padded positions zeroed out. + """ + b, t, d, l = encoded_text.shape # noqa: E741 + device = encoded_text.device + # Build mask: [B, T, 1, 1] + token_indices = torch.arange(t, device=device)[None, :] # [1, T] + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [B, T] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = t - sequence_lengths[:, None] # [B, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = rearrange(mask, "b t -> b t 1 1") + eps = 1e-6 + # Compute masked mean: [B, 1, 1, L] + masked = encoded_text.masked_fill(~mask, 0.0) + denom = (sequence_lengths * d).view(b, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps) + # Compute masked min/max: [B, 1, 1, L] + x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + range_ = x_max - x_min + # Normalize only the valid tokens + normed = 8 * (encoded_text - mean) / (range_ + eps) + # concat to be [Batch, T, D * L] - this preserves the original structure + normed = normed.reshape(b, t, -1) # [B, T, D * L] + # Apply mask to preserve original padding (set padded positions to 0) + mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l) + normed = normed.masked_fill(~mask_flattened, 0.0) + + return normed + + +def _convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return (attention_mask - 1).to(dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max + +def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert connector output mask to binary mask and apply to encoded tensor.""" + binary_mask = (encoded_mask < 0.000001).to(torch.int64) + binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * binary_mask + return encoded, binary_mask + + +def norm_and_concat_per_token_rms( + encoded_text: torch.Tensor, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Per-token RMSNorm normalization for V2 models. + Args: + encoded_text: [B, T, D, L] + attention_mask: [B, T] binary mask + Returns: + [B, T, D*L] normalized tensor with padding zeroed out. 
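+    Example (illustrative shapes only, assuming Gemma-sized features with D=3840 hidden dims over L=49 layers):
+        >>> hidden = torch.randn(2, 16, 3840, 49)
+        >>> mask = torch.ones(2, 16, dtype=torch.long)
+        >>> norm_and_concat_per_token_rms(hidden, mask).shape
+        torch.Size([2, 16, 188160])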
+ """ + B, T, D, L = encoded_text.shape # noqa: N806 + variance = torch.mean(encoded_text**2, dim=2, keepdim=True) # [B,T,1,L] + normed = encoded_text * torch.rsqrt(variance + 1e-6) + normed = normed.reshape(B, T, D * L) + mask_3d = attention_mask.bool().unsqueeze(-1) # [B, T, 1] + return torch.where(mask_3d, normed, torch.zeros_like(normed)) + + +def _rescale_norm(x: torch.Tensor, target_dim: int, source_dim: int) -> torch.Tensor: + """Rescale normalization: x * sqrt(target_dim / source_dim).""" + return x * math.sqrt(target_dim / source_dim) diff --git a/diffsynth/models/ltx2_upsampler.py b/diffsynth/models/ltx2_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..862ca14bc320880f2b10c933fadc575c8ebad681 --- /dev/null +++ b/diffsynth/models/ltx2_upsampler.py @@ -0,0 +1,313 @@ +import math +from typing import Optional, Tuple +import torch +from einops import rearrange +import torch.nn.functional as F +from .ltx2_video_vae import LTX2VideoEncoder + +class PixelShuffleND(torch.nn.Module): + """ + N-dimensional pixel shuffle operation for upsampling tensors. + Args: + dims (int): Number of dimensions to apply pixel shuffle to. + - 1: Temporal (e.g., frames) + - 2: Spatial (e.g., height and width) + - 3: Spatiotemporal (e.g., depth, height, width) + upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension. + For dims=1, only the first value is used. + For dims=2, the first two values are used. + For dims=3, all three values are used. + The input tensor is rearranged so that the channel dimension is split into + smaller channels and upscaling factors, and the upscaling factors are moved + into the corresponding spatial/temporal dimensions. + Note: + This operation is equivalent to the patchifier operation in for the models. Consider + using this class instead. + """ + + def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + else: + raise ValueError(f"Unsupported dims: {self.dims}") + + +class ResBlock(torch.nn.Module): + """ + Residual block with two convolutional layers, group normalization, and SiLU activation. + Args: + channels (int): Number of input and output channels. + mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels` + if not specified. + dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3. 
+ """ + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + assert dims in (2, 3) + assert isinstance(stride, int) + assert stride >= 1 + assert kernel_size >= 3 + assert kernel_size % 2 == 1 + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + return self._apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self._apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor: + c = x2d.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + return x2d + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}") + return mapping[float(scale)] + + +class SpatialRationalResampler(torch.nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + Args: + mid_channels (`int`): Number of intermediate channels for the convolution layer + scale (`float`): Spatial scaling factor. Supported values are: + - 0.75: Downsample by 3/4 (reduce spatial size) + - 1.5: Upsample by 3/2 (increase spatial size) + - 2.0: Upsample by 2x (double spatial size) + - 4.0: Upsample by 4x (quadruple spatial size) + Any other value will raise a ValueError. 
+ """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class LTX2LatentUpsampler(torch.nn.Module): + """ + Model to upsample VAE latents spatially and/or temporally. + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + spatial_scale (`float`): Scale factor for spatial upsampling + rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling + """ + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = 
self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + # remove the first frame after upsampling. + # This is done because the first frame encodes one pixel frame. + x = x[:, :, 1:, :, :] + elif isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + +def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor: + """ + Apply upsampling to the latent representation using the provided upsampler, + with normalization and un-normalization based on the video encoder's per-channel statistics. + Args: + latent: Input latent tensor of shape [B, C, F, H, W]. + video_encoder: VideoEncoder with per_channel_statistics for normalization. + upsampler: LTX2LatentUpsampler module to perform upsampling. + Returns: + torch.Tensor: Upsampled and re-normalized latent tensor. + """ + latent = video_encoder.per_channel_statistics.un_normalize(latent) + latent = upsampler(latent) + latent = video_encoder.per_channel_statistics.normalize(latent) + return latent diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..a70dc80e830e6910976451247b11334e824f34d4 --- /dev/null +++ b/diffsynth/models/ltx2_video_vae.py @@ -0,0 +1,2322 @@ +import itertools +import math +import einops +from dataclasses import replace, dataclass +from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from enum import Enum +from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape +from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings + +VAE_SPATIAL_FACTOR = 32 +VAE_TEMPORAL_FACTOR = 8 + + +class VideoLatentPatchifier(Patchifier): + def __init__(self, patch_size: int): + # Patch sizes for video latents. 
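+        # The temporal patch size is fixed at 1 (unpatchify asserts this), so only height and width are folded into tokens.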
+ self._patch_size = ( + 1, # temporal dimension + patch_size, # height dimension + patch_size, # width dimension + ) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: VideoLatentShape) -> int: + return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + + return latents + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: VideoLatentShape, + ) -> torch.Tensor: + assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" + + patch_grid_frames = output_shape.frames // self._patch_size[0] + patch_grid_height = output_shape.height // self._patch_size[1] + patch_grid_width = output_shape.width // self._patch_size[2] + + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=patch_grid_frames, + h=patch_grid_height, + w=patch_grid_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + + return latents + + def unpatchify_video( + self, + latents: torch.Tensor, + frames: int, + height: int, + width: int, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=frames, + h=height // self._patch_size[1], + w=width // self._patch_size[2], + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the per-dimension bounds [inclusive start, exclusive end) for every + patch produced by `patchify`. The bounds are expressed in the original + video grid coordinates: frame/time, height, and width. + The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: + - axis 1 (size 3) enumerates (frame/time, height, width) dimensions + - axis 3 (size 2) stores `[start, end)` indices within each dimension + Args: + output_shape: Video grid description containing frames, height, and width. + device: Device of the latent tensor. + """ + if not isinstance(output_shape, VideoLatentShape): + raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") + + frames = output_shape.frames + height = output_shape.height + width = output_shape.width + batch_size = output_shape.batch + + # Validate inputs to ensure positive dimensions + assert frames > 0, f"frames must be positive, got {frames}" + assert height > 0, f"height must be positive, got {height}" + assert width > 0, f"width must be positive, got {width}" + assert batch_size > 0, f"batch_size must be positive, got {batch_size}" + + # Generate grid coordinates for each dimension (frame, height, width) + # We use torch.arange to create the starting coordinates for each patch. + # indexing='ij' ensures the dimensions are in the order (frame, height, width). + grid_coords = torch.meshgrid( + torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), + torch.arange(start=0, end=height, step=self._patch_size[1], device=device), + torch.arange(start=0, end=width, step=self._patch_size[2], device=device), + indexing="ij", + ) + + # Stack the grid coordinates to create the start coordinates tensor. 
+ # Shape becomes (3, grid_f, grid_h, grid_w) + patch_starts = torch.stack(grid_coords, dim=0) + + # Create a tensor containing the size of a single patch: + # (frame_patch_size, height_patch_size, width_patch_size). + # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates. + patch_size_delta = torch.tensor( + self._patch_size, + device=patch_starts.device, + dtype=patch_starts.dtype, + ).view(3, 1, 1, 1) + + # Calculate end coordinates: start + patch_size + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_ends = patch_starts + patch_size_delta + + # Stack start and end coordinates together along the last dimension + # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end] + latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) + + # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence. + # Final Shape: (batch_size, 3, num_patches, 2) + latent_coords = einops.repeat( + latent_coords, + "c f h w bounds -> b c (f h w) bounds", + b=batch_size, + bounds=2, + ) + + return latent_coords + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" + + +class DualConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. 
Use make_linear_nd instead.")
+        if isinstance(stride, int):
+            stride = (stride, stride, stride)
+        if isinstance(padding, int):
+            padding = (padding, padding, padding)
+        if isinstance(dilation, int):
+            dilation = (dilation, dilation, dilation)
+
+        # Set parameters for convolutions
+        self.groups = groups
+        self.bias = bias
+
+        # Define the size of the channels after the first convolution
+        intermediate_channels = out_channels if in_channels < out_channels else in_channels
+
+        # Define parameters for the first convolution
+        self.weight1 = nn.Parameter(
+            torch.Tensor(
+                intermediate_channels,
+                in_channels // groups,
+                1,
+                kernel_size[1],
+                kernel_size[2],
+            ))
+        self.stride1 = (1, stride[1], stride[2])
+        self.padding1 = (0, padding[1], padding[2])
+        self.dilation1 = (1, dilation[1], dilation[2])
+        if bias:
+            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
+        else:
+            self.register_parameter("bias1", None)
+
+        # Define parameters for the second convolution
+        self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
+        self.stride2 = (stride[0], 1, 1)
+        self.padding2 = (padding[0], 0, 0)
+        self.dilation2 = (dilation[0], 1, 1)
+        if bias:
+            self.bias2 = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter("bias2", None)
+
+        # Initialize weights and biases
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        # The gain argument and fan-in values are plain Python scalars, so use math.sqrt rather than torch.sqrt.
+        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
+        if self.bias:
+            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
+            bound1 = 1 / math.sqrt(fan_in1)
+            nn.init.uniform_(self.bias1, -bound1, bound1)
+            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
+            bound2 = 1 / math.sqrt(fan_in2)
+            nn.init.uniform_(self.bias2, -bound2, bound2)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        use_conv3d: bool = False,
+        skip_time_conv: bool = False,
+    ) -> torch.Tensor:
+        if use_conv3d:
+            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
+        else:
+            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
+
+    def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
+        # First convolution. Note: the functional conv ops below always use zero padding;
+        # non-zero padding modes are not supported on this path.
+        x = F.conv3d(
+            x,
+            self.weight1,
+            self.bias1,
+            self.stride1,
+            self.padding1,
+            self.dilation1,
+            self.groups,
+        )
+
+        if skip_time_conv:
+            return x
+
+        # Second convolution
+        x = F.conv3d(
+            x,
+            self.weight2,
+            self.bias2,
+            self.stride2,
+            self.padding2,
+            self.dilation2,
+            self.groups,
+        )
+
+        return x
+
+    def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
+        b, _, _, h, w = x.shape
+
+        # First 2D convolution
+        x = rearrange(x, "b c d h w -> (b d) c h w")
+        # Squeeze the depth dimension out of weight1 since it's 1
+        weight1 = self.weight1.squeeze(2)
+        # Select stride, padding, and dilation for the 2D convolution
+        stride1 = (self.stride1[1], self.stride1[2])
+        padding1 = (self.padding1[1], self.padding1[2])
+        dilation1 = (self.dilation1[1], self.dilation1[2])
+        x = F.conv2d(
+            x,
+            weight1,
+            self.bias1,
+            stride1,
+            padding1,
+            dilation1,
+            self.groups,
+        )
+
+        _, _, h, w = x.shape
+
+        if skip_time_conv:
+            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
+            return x
+
+        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
+        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
+
+        # Reshape weight2 to match the expected dimensions for conv1d
+        weight2 = self.weight2.squeeze(-1).squeeze(-1)
+        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
+        stride2 = self.stride2[0]
+        padding2 = self.padding2[0]
+        dilation2 = self.dilation2[0]
+        x = F.conv1d(
+            x,
+            weight2,
+            self.bias2,
+            stride2,
+            padding2,
+            dilation2,
+            self.groups,
+        )
+        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
+
+        return x
+
+    @property
+    def weight(self) -> torch.Tensor:
+        return self.weight2
+
+
+class CausalConv3d(nn.Module):
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: Union[int, Tuple[int]] = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        kernel_size = (kernel_size, kernel_size, kernel_size)
+        self.time_kernel_size = kernel_size[0]
+
+        dilation = (dilation, 1, 1)
+
+        height_pad = kernel_size[1] // 2
+        width_pad = kernel_size[2] // 2
+        padding = (0, height_pad, width_pad)
+
+        self.conv = nn.Conv3d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            dilation=dilation,
+            padding=padding,
+            padding_mode=spatial_padding_mode.value,
+            groups=groups,
+            bias=bias,
+        )
+
+    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
+        if causal:
+            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
+            x = torch.concatenate((first_frame_pad, x), dim=2)
+        else:
+            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
+            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
+            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
+        x = self.conv(x)
+        return x
+
+    @property
+    def weight(self) -> torch.Tensor:
+        return self.conv.weight
+
+
+def make_conv_nd( # noqa: PLR0913
+    dims: Union[int, Tuple[int, int]],
+    in_channels: int,
+    out_channels: int,
+    kernel_size: int,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    bias: bool = True,
+    causal: bool = False,
+    spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
+    temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
+) -> nn.Module:
+    if not (spatial_padding_mode == temporal_padding_mode or causal):
+        raise NotImplementedError("spatial and temporal padding modes must be equal")
+    if dims == 2:
+        return nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            padding_mode=spatial_padding_mode.value,
+        )
+    elif dims == 3:
+        if causal:
+            return CausalConv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+                spatial_padding_mode=spatial_padding_mode,
+            )
+        return nn.Conv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            padding_mode=spatial_padding_mode.value,
+        )
+    elif dims == (2, 1):
+        return DualConv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+            padding_mode=spatial_padding_mode.value,
+        )
+    else:
+        raise
ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks + and moves pixels from each block into separate channels (space-to-depth). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). + For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) + Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from + channels back into patch_size x patch_size blocks (depth-to-space). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). + For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw) + Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict. + """ + + def __init__(self, latent_channels: int = 128): + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( + 1, -1, 1, 1, 1).to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( + 1, -1, 1, 1, 1).to(x) + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. 
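+    Applies norm -> SiLU -> causal conv twice, adds a (possibly 1x1-projected) shortcut, and can optionally
+    inject per-channel spatial noise or condition the activations on a timestep embedding.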
+ Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels else nn.Identity()) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels else nn.Identity()) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] 
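+ # spatial_noise is (1, H, W) and per_channel_scale is (C, 1, 1), so their product is (C, H, W);
+ # the [None, :, None, ...] indexing then yields (1, C, 1, H, W) so the noise broadcasts over batch and frames.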
+ hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
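+ For 5D video inputs the shape is preserved, i.e. the output is
+ `(batch_size, in_channels, frames, height, width)`.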
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4, + size_emb_dim=0) + + self.res_blocks = nn.ModuleList([ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) for _ in range(num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + 
dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> torch.Tensor: + """ + Generate a 1D trapezoidal blending mask with linear ramps. + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + Returns: + A 1D tensor of shape `(length,)` with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + mask = torch.ones(length) + + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] + if not left_starts_from_0: + fade_in = fade_in[1:] + mask[:ramp_left] *= fade_in + + if ramp_right > 0: + fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] + mask[-ramp_right:] *= fade_out + + return mask.clamp_(0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. 
+ tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. + """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + +@dataclass(frozen=True) +class DimensionIntervals: + """Intervals which a single dimension of the latent space is split into. + Each interval is defined by its start, end, left ramp, and right ramp. + The start and end are the indices of the first and last element (exclusive) in the interval. + Ramps are regions of the interval where the value of the mask tensor is + interpolated between 0 and 1 for blending with neighboring intervals. + The left ramp and right ramp values are the lengths of the left and right ramps. + """ + + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +@dataclass(frozen=True) +class LatentIntervals: + """Intervals which the latent tensor of given shape is split into. + Each dimension of the latent space is split into intervals based on the length along said dimension. + """ + + original_shape: torch.Size + dimension_intervals: Tuple[DimensionIntervals, ...] + + +# Operation to split a single dimension of the tensor into intervals based on the length along the dimension. +SplitOperation = Callable[[int], DimensionIntervals] +# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension. +MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]] + + +def default_split_operation(length: int) -> DimensionIntervals: + return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0]) + + +DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation + + +def default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]: + return [slice(0, None)], [None] + + +DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation + + +class Tile(NamedTuple): + """ + Represents a single tile. + Attributes: + in_coords: + Tuple of slices specifying where to cut the tile from the INPUT tensor. 
+ out_coords: + Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. + masks_1d: + Per-dimension masks in OUTPUT units. + These are used to create all-dimensional blending mask. + Methods: + blend_mask: + Create a single N-D mask from the per-dimension masks. + """ + + in_coords: Tuple[slice, ...] + out_coords: Tuple[slice, ...] + masks_1d: Tuple[Tuple[torch.Tensor, ...]] + + @property + def blend_mask(self) -> torch.Tensor: + num_dims = len(self.out_coords) + per_dimension_masks: List[torch.Tensor] = [] + + for dim_idx in range(num_dims): + mask_1d = self.masks_1d[dim_idx] + view_shape = [1] * num_dims + if mask_1d is None: + # Broadcast mask along this dimension (length 1). + one = torch.ones(1) + + view_shape[dim_idx] = 1 + per_dimension_masks.append(one.view(*view_shape)) + continue + + # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply. + view_shape[dim_idx] = mask_1d.shape[0] + per_dimension_masks.append(mask_1d.view(*view_shape)) + + # Multiply per-dimension masks to form the full N-D mask (separable blending window). + combined_mask = per_dimension_masks[0] + for mask in per_dimension_masks[1:]: + combined_mask = combined_mask * mask + + return combined_mask + + +def create_tiles_from_intervals_and_mappers( + intervals: LatentIntervals, + mappers: List[MappingOperation], +) -> List[Tile]: + full_dim_input_slices = [] + full_dim_output_slices = [] + full_dim_masks_1d = [] + for axis_index in range(len(intervals.original_shape)): + dimension_intervals = intervals.dimension_intervals[axis_index] + starts = dimension_intervals.starts + ends = dimension_intervals.ends + input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] + output_slices, masks_1d = mappers[axis_index](dimension_intervals) + full_dim_input_slices.append(input_slices) + full_dim_output_slices.append(output_slices) + full_dim_masks_1d.append(masks_1d) + + tiles = [] + tile_in_coords = list(itertools.product(*full_dim_input_slices)) + tile_out_coords = list(itertools.product(*full_dim_output_slices)) + tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) + for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): + tiles.append(Tile( + in_coords=in_coord, + out_coords=out_coord, + masks_1d=mask_1d, + )) + return tiles + + +def create_tiles( + latent_shape: torch.Size, + splitters: List[SplitOperation], + mappers: List[MappingOperation], +) -> List[Tile]: + if len(splitters) != len(latent_shape): + raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, " + f"got {len(splitters)} and {len(latent_shape)}") + if len(mappers) != len(latent_shape): + raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, " + f"got {len(mappers)} and {len(latent_shape)}") + intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)] + latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals)) + return create_tiles_from_intervals_and_mappers(latent_intervals, mappers) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + 
in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class LTX2VideoEncoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. 
Sequential compression through encoder_blocks based on their stride patterns + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + encoder_version: str = "ltx-2", + ): + super().__init__() + if encoder_version == "ltx-2": + encoder_blocks = [ + ['res_x', {'num_layers': 4}], + ['compress_space_res', {'multiplier': 2}], + ['res_x', {'num_layers': 6}], + ['compress_time_res', {'multiplier': 2}], + ['res_x', {'num_layers': 6}], + ['compress_all_res', {'multiplier': 2}], + ['res_x', {'num_layers': 2}], + ['compress_all_res', {'multiplier': 2}], + ['res_x', {'num_layers': 2}] + ] + else: + # LTX-2.3 + encoder_blocks = [ + ["res_x", {"num_layers": 4}], + ["compress_space_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 4}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ["compress_all_res", {"multiplier": 1}], + ["res_x", {"num_layers": 2}] + ] + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + 
block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + Args: + sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + frames_to_crop = (frames_count - 1) % 8 + sample = sample[:, :, :-frames_to_crop, ...] + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}") + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] 
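+ # Constant variance: drop the predicted logvar channel and append a constant logvar of ~-30
+ # (essentially zero variance) so the output keeps the (means, logvar) channel layout.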
+ approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + + def tiled_encode_video( + self, + video: torch.Tensor, + tile_size: int = 512, + tile_overlap: int = 128, + ) -> torch.Tensor: + """Encode video using spatial tiling for memory efficiency. + Splits the video into overlapping spatial tiles, encodes each tile separately, + and blends the results using linear feathering in the overlap regions. + Args: + video: Input tensor of shape [B, C, F, H, W] + tile_size: Tile size in pixels (must be divisible by 32) + tile_overlap: Overlap between tiles in pixels (must be divisible by 32) + Returns: + Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent] + """ + batch, _channels, frames, height, width = video.shape + device = video.device + dtype = video.dtype + + # Validate tile parameters + if tile_size % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}") + if tile_overlap % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}") + if tile_overlap >= tile_size: + raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})") + + # If video fits in a single tile, use regular encoding + if height <= tile_size and width <= tile_size: + return self.forward(video) + + # Calculate output dimensions + # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8 + output_height = height // VAE_SPATIAL_FACTOR + output_width = width // VAE_SPATIAL_FACTOR + output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR + + # Latent channels (128 for LTX-2) + # Get from a small test encode or assume 128 + latent_channels = 128 + + # Initialize output and weight tensors + output = torch.zeros( + (batch, latent_channels, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + weights = torch.zeros( + (batch, 1, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + + # Calculate tile positions with overlap + # Step size is tile_size - tile_overlap + step_h = tile_size - tile_overlap + step_w = tile_size - tile_overlap + + h_positions = list(range(0, max(1, height - tile_overlap), step_h)) + w_positions = list(range(0, max(1, width - tile_overlap), step_w)) + + # Ensure last tile covers the edge + if h_positions[-1] + tile_size < height: + h_positions.append(height - tile_size) + if w_positions[-1] + tile_size < width: + w_positions.append(width - tile_size) + + # Remove duplicates and sort + h_positions = sorted(set(h_positions)) + w_positions = sorted(set(w_positions)) + + # Overlap in latent space + overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR + overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR + + # Process each tile + for h_pos in h_positions: + for w_pos in w_positions: + # Calculate tile boundaries in input space + h_start = max(0, h_pos) + w_start = max(0, w_pos) + h_end = min(h_start + tile_size, height) + w_end = min(w_start + tile_size, width) + + # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR + tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + + if tile_h < 
VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR: + continue + + # Adjust end positions + h_end = h_start + tile_h + w_end = w_start + tile_w + + # Extract tile + tile = video[:, :, :, h_start:h_end, w_start:w_end] + + # Encode tile + encoded_tile = self.forward(tile) + + # Get actual encoded dimensions + _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape + + # Calculate output positions + out_h_start = h_start // VAE_SPATIAL_FACTOR + out_w_start = w_start // VAE_SPATIAL_FACTOR + out_h_end = min(out_h_start + tile_out_height, output_height) + out_w_end = min(out_w_start + tile_out_width, output_width) + + # Trim encoded tile if necessary + actual_tile_h = out_h_end - out_h_start + actual_tile_w = out_w_end - out_w_start + encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w] + + # Create blending mask with linear feathering at edges + mask = torch.ones( + (1, 1, tile_out_frames, actual_tile_h, actual_tile_w), + device=device, + dtype=dtype, + ) + + # Apply feathering at edges (linear blend in overlap regions) + # Left edge + if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1) + + # Right edge (bottom in height dimension) + if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1) + + # Top edge (left in width dimension) + if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1) + + # Bottom edge (right in width dimension) + if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1) + + # Accumulate weighted results + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask + + # Normalize by weights (avoid division by zero) + output = output / (weights + 1e-8) + + return output + + def encode( + self, + video: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + **kwargs, + ) -> torch.Tensor: + if video.ndim == 4: + video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W] + # Choose encoding method based on tiling flag + if tiled: + latents = self.tiled_encode_video( + video=video, + tile_size=tile_size_in_pixels, + tile_overlap=tile_overlap_in_pixels, + ) + else: + # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W'] + latents = self.forward(video) + return latents + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, 
+ norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class LTX2VideoDecoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. + First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. 
For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + decoder_version: str = "ltx-2", + base_channels: int = 128, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. + if decoder_version == "ltx-2": + decoder_blocks = [ + ['res_x', {'num_layers': 5, 'inject_noise': False}], + ['compress_all', {'residual': True, 'multiplier': 2}], + ['res_x', {'num_layers': 5, 'inject_noise': False}], + ['compress_all', {'residual': True, 'multiplier': 2}], + ['res_x', {'num_layers': 5, 'inject_noise': False}], + ['compress_all', {'residual': True, 'multiplier': 2}], + ['res_x', {'num_layers': 5, 'inject_noise': False}] + ] + else: + # LTX-2.3 + decoder_blocks = [ + ["res_x", {"num_layers": 4}], + ["compress_space", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time", {"multiplier": 2}], + ["res_x", {"num_layers": 4}], + ["compress_all", {"multiplier": 1}], + ["res_x", {"num_layers": 2}], + ["compress_all", {"multiplier": 2}], + ["res_x", {"num_layers": 2}] + ] + self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # LTX VAE decoder architecture uses 3 upsampler blocks with multiplier equals to 2. + # Hence the total feature_channels is multiplied by 8 (2^3). 
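+ # With base_channels=128 this starts at 1024 features; each "compress_*" upsampler with multiplier 2
+ # halves the channel count on the way up, ending back at base_channels before conv_out.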
+ feature_channels = base_channels * 8 + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2, + size_emb_dim=0) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + Args: + sample: Latent tensor (B, 128, F', H', W'). + timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). + Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. 
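+ When timestep_conditioning is enabled, a small amount of noise (decode_noise_scale) is mixed into
+ the latents before denormalization, and an adaptive scale/shift from the embedded timestep is
+ applied before the final convolution.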
+ """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = (torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) * self.decode_noise_scale) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + long_side = max(latent.shape[3], latent.shape[4]) + + def enable_on_axis(axis_idx: int, factor: int) -> None: + size = cfg.tile_size_in_pixels // factor + overlap = cfg.tile_overlap_in_pixels // factor + axis_length = latent.shape[axis_idx] + lower_threshold = max(2, overlap + 1) + tile_size = max(lower_threshold, round(size * axis_length / long_side)) + splitters[axis_idx] = split_in_spatial(tile_size, overlap) + mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor) + + enable_on_axis(3, self.video_downscale_factors.height) + enable_on_axis(4, self.video_downscale_factors.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time + 
overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time + splitters[2] = split_in_temporal(tile_size, overlap) + mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time) + + return create_tiles(latent.shape, splitters, mappers) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """ + Decode a latent tensor into video frames using tiled processing. + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop - + curr_temporal_slice.start,) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. 
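+ # The exchange below is symmetric: the current buffer's overlap contribution is added into the
+ # previous chunk (and its weights), and the combined values are then copied back into the current
+ # buffer so the overlap region carries the full accumulation into the next iteration.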
+ previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :, + slice(0, overlap_len), :, :] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :, + temporal_overlap_slice, :, :] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: torch.Tensor | None, + generator: torch.Generator | None, + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. 
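+ Returns:
+ The accumulated blend-weight tensor (same shape as `buffer`), used by the caller to
+ normalize the blended output.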
+ """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights + + def decode( + self, + latent: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + ) -> torch.Tensor: + if tiled: + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tile_size_in_pixels, + tile_overlap_in_pixels=tile_overlap_in_pixels, + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tile_size_in_frames, + tile_overlap_in_frames=tile_overlap_in_frames, + ), + ) + tiles = self.tiled_decode(latent, tiling_config) + return torch.cat(list(tiles), dim=2) + else: + return self.forward(latent) + +def decode_video( + latent: torch.Tensor, + video_decoder: LTX2VideoDecoder, + tiling_config: TilingConfig | None = None, + generator: torch.Generator | None = None, +) -> Iterator[torch.Tensor]: + """ + Decode a video latent tensor with the given decoder. + Args: + latent: Tensor [c, f, h, w] + video_decoder: Decoder module. + tiling_config: Optional tiling settings. + generator: Optional random generator for deterministic decoding. + Yields: + Decoded chunk [f, h, w, c], uint8 in [0, 255]. + """ + + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator): + return convert_to_uint8(frames) + else: + decoded_video = video_decoder(latent, generator=generator) + return convert_to_uint8(decoded_video) + + +def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: + """ + Get the number of video chunks for a given number of frames and tiling configuration. + Args: + num_frames: Number of frames in the video. + tiling_config: Tiling configuration. + Returns: + Number of video chunks. 
+ """ + if not tiling_config or not tiling_config.temporal_config: + return 1 + cfg = tiling_config.temporal_config + frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames + return (num_frames - 1 + frame_stride - 1) // frame_stride + + +def split_in_spatial(size: int, overlap: int) -> SplitOperation: + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + return split + + +def split_in_temporal(size: int, overlap: int) -> SplitOperation: + non_causal_split = split_in_spatial(size, overlap) + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + intervals = non_causal_split(dimension_size) + starts = intervals.starts + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps = intervals.left_ramps + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + return replace(intervals, starts=starts, left_ramps=left_ramps) + + return split + + +def to_mapping_operation( + map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]], + scale: int, +) -> MappingOperation: + + def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: + output_slices: list[slice] = [] + masks_1d: list[torch.Tensor | None] = [] + number_of_slices = len(intervals.starts) + for i in range(number_of_slices): + start = intervals.starts[i] + end = intervals.ends[i] + left_ramp = intervals.left_ramps[i] + right_ramp = intervals.right_ramps[i] + output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + return output_slices, masks_1d + + return map_op + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7a716e25995f354f41bc3501284c04138e355a57 --- /dev/null +++ b/diffsynth/models/model_loader.py @@ -0,0 +1,113 @@ +from ..core.loader import load_model, hash_model_file +from ..core.vram import AutoWrappedModule +from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS +import importlib, json, torch + + +class ModelPool: + def __init__(self): + self.model = [] + self.model_name = [] + self.model_path = [] + + def import_model_class(self, model_class): + split = model_class.rfind(".") + model_resource, model_class = 
model_class[:split], model_class[split+1:] + model_class = importlib.import_module(model_resource).__getattribute__(model_class) + return model_class + + def need_to_enable_vram_management(self, vram_config): + return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None + + def fetch_module_map(self, model_class, vram_config): + if self.need_to_enable_vram_management(vram_config): + if model_class in VRAM_MANAGEMENT_MODULE_MAPS: + vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]() + module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()} + else: + module_map = {self.import_model_class(model_class): AutoWrappedModule} + else: + module_map = None + return module_map + + def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None): + model_class = self.import_model_class(config["model_class"]) + model_config = config.get("extra_kwargs", {}) + if "state_dict_converter" in config: + state_dict_converter = self.import_model_class(config["state_dict_converter"]) + else: + state_dict_converter = None + module_map = self.fetch_module_map(config["model_class"], vram_config) + model = load_model( + model_class, path, model_config, + vram_config["computation_dtype"], vram_config["computation_device"], + state_dict_converter, + use_disk_map=True, + vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, + state_dict=state_dict, + ) + return model + + def default_vram_config(self): + vram_config = { + "offload_dtype": None, + "offload_device": None, + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cpu", + "computation_dtype": torch.bfloat16, + "computation_device": "cpu", + } + return vram_config + + def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None): + print(f"Loading models from: {json.dumps(path, indent=4)}") + if vram_config is None: + vram_config = self.default_vram_config() + model_hash = hash_model_file(path) + loaded = False + for config in MODEL_CONFIGS: + if config["model_hash"] == model_hash: + model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict) + if clear_parameters: self.clear_parameters(model) + self.model.append(model) + model_name = config["model_name"] + self.model_name.append(model_name) + self.model_path.append(path) + model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")} + print(f"Loaded model: {json.dumps(model_info, indent=4)}") + loaded = True + if not loaded: + raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}") + + def fetch_model(self, model_name, index=None): + fetched_models = [] + fetched_model_paths = [] + for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): + if model_name == model_name_: + fetched_models.append(model) + fetched_model_paths.append(model_path) + if len(fetched_models) == 0: + print(f"No {model_name} models available. 
This is not an error.") + model = None + elif len(fetched_models) == 1: + print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + model = fetched_models[0] + else: + if index is None: + model = fetched_models[0] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + elif isinstance(index, int): + model = fetched_models[:index] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.") + else: + model = fetched_models + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.") + return model + + def clear_parameters(self, model: torch.nn.Module): + for name, module in model.named_children(): + self.clear_parameters(module) + for name, param in model.named_parameters(recurse=False): + setattr(model, name, None) diff --git a/diffsynth/models/mova_audio_dit.py b/diffsynth/models/mova_audio_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2b5d1335ff6503db41455617605cccd3b28b18 --- /dev/null +++ b/diffsynth/models/mova_audio_dit.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d +from einops import rearrange +from ..core import gradient_checkpoint_forward + +def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0): + f_freqs_cis = precompute_freqs_cis(dim, end, theta) + return f_freqs_cis.chunk(3, dim=-1) + +class MovaAudioDit(WanModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12) + self.freqs = precompute_freqs_cis_1d(head_dim) + self.patch_embedding = nn.Conv1d( + kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1] + ) + + def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0): + self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + x, (f, ) = self.patchify(x) + freqs = torch.cat([ + self.freqs[0][:f].view(f, -1).expand(f, -1), + self.freqs[1][:f].view(f, -1).expand(f, -1), + self.freqs[2][:f].view(f, -1).expand(f, -1), + ], dim=-1).reshape(f, 1, -1).to(x.device) + + for block in self.blocks: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs, + ) + x = self.head(x, t) + x = self.unpatchify(x, (f, )) + return x + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b f (p c) -> b c (f p)', + f=grid_size[0], + p=self.patch_size[0] + ) diff --git a/diffsynth/models/mova_audio_vae.py b/diffsynth/models/mova_audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..570cd43f8f1c5cd3d3f182d8a5fbfe81106b49f2 --- /dev/null +++ b/diffsynth/models/mova_audio_vae.py @@ -0,0 +1,796 @@ +import math +from typing import List, Union +import numpy as np +import torch +from 
torch import nn +from torch.nn.utils import weight_norm +import torch.nn.functional as F +from einops import rearrange + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + 
Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, 
latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2], + ) + + def nll(self, sample, dims=[1, 2]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DacVAE(nn.Module): + + def __init__( + self, + encoder_dim: int = 128, + encoder_rates: List[int] = [2, 3, 4, 5, 8], + latent_dim: int = 128, + decoder_dim: int = 2048, + decoder_rates: List[int] = [8, 5, 4, 3, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 48000, + 
continuous: bool = True, + use_weight_norm: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + self.continuous = continuous + self.use_weight_norm = use_weight_norm + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + if not continuous: + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + if not self.use_weight_norm: + self.remove_weight_norm() + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @property + def dtype(self): + """Get the dtype of the model parameters.""" + # Return the dtype of the first parameter found + for param in self.parameters(): + return param.dtype + return torch.float32 # fallback + + @property + def device(self): + """Get the device of the model parameters.""" + # Return the device of the first parameter found + for param in self.parameters(): + return param.device + return torch.device('cpu') # fallback + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) # [B x D x T] + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + else: + z = self.quant_conv(z) # [B x 2D x T] + z = DiagonalGaussianDistribution(z) + codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 + + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + if not self.continuous: + audio = self.decoder(z) + else: + z = self.post_quant_conv(z) + audio = self.decoder(z) + + return audio + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + else: + posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) + z = posterior.sample() + x = self.decode(z) + + kl_loss = posterior.kl() + kl_loss = kl_loss.mean() + + return { + "audio": x[..., :length], + "z": z, + "kl_loss": kl_loss, + } + + def remove_weight_norm(self): + """ + Remove weight_norm from all modules in the model. + This fuses the weight_g and weight_v parameters into a single weight parameter. + Should be called before inference for better performance. 
+ Returns: + self: The model with weight_norm removed + """ + from torch.nn.utils import remove_weight_norm + num_removed = 0 + for name, module in list(self.named_modules()): + if hasattr(module, "_forward_pre_hooks"): + for hook_id, hook in list(module._forward_pre_hooks.items()): + if "WeightNorm" in str(type(hook)): + try: + remove_weight_norm(module) + num_removed += 1 + # print(f"Removed weight_norm from: {name}") + except ValueError as e: + print(f"Failed to remove weight_norm from {name}: {e}") + if num_removed > 0: + # print(f"Successfully removed weight_norm from {num_removed} modules") + self.use_weight_norm = False + else: + print("No weight_norm found in the model") + return self diff --git a/diffsynth/models/mova_dual_tower_bridge.py b/diffsynth/models/mova_dual_tower_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb342e011ab91fe8104e4fcc0962b85f4a15c0f --- /dev/null +++ b/diffsynth/models/mova_dual_tower_bridge.py @@ -0,0 +1,595 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, List, Tuple, Optional +from einops import rearrange +from .wan_video_dit import AttentionModule, RMSNorm +from ..core import gradient_checkpoint_forward + +class RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, base: float, dim: int, device=None): + super().__init__() + self.base = base + self.dim = dim + self.attention_scaling = 1.0 + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@torch.compile(fullgraph=True) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PerFrameAttentionPooling(nn.Module): + """ + Per-frame multi-head attention pooling. + + Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a + single-query attention pooling over the H*W tokens for each time frame, producing + [B, T, D]. + + Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack). + """ + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + assert dim % num_heads == 0, "dim must be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.randn(1, 1, dim)) + nn.init.normal_(self.probe, std=0.02) + + self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.layernorm = nn.LayerNorm(dim, eps=eps) + + def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor: + """ + Args: + x: [B, L, D], where L = T*H*W + grid_size: (T, H, W) + Returns: + pooled: [B, T, D] + """ + B, L, D = x.shape + T, H, W = grid_size + assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}" + assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}" + + S = H * W + # Re-arrange tokens grouped by frame. + x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D] + + # A learnable probe as the query (one query per frame). + probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D] + + # Attention pooling: query=probe, key/value=H*W tokens within the frame. + pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D] + pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D] + + # Restore to [B, T, D]. + pooled = pooled_bt_d.view(B, T, D) + pooled = self.layernorm(pooled) + return pooled + + +class CrossModalInteractionController: + """ + Strategy class that controls interactions between two towers. + Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers). + """ + + def __init__(self, visual_layers: int = 30, audio_layers: int = 30): + self.visual_layers = visual_layers + self.audio_layers = audio_layers + self.min_layers = min(visual_layers, audio_layers) + + def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]: + """ + Get interaction layer mappings. + + Args: + strategy: interaction strategy + - "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry + - "distributed": distributed interactions across the network + - "progressive": dense shallow interactions, sparse deeper interactions + - "custom": custom interaction layers + + Returns: + A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual). + """ + + if strategy == "shallow_focus": + # Emphasize the first ~1/3 layers to avoid deep-layer asymmetry. 
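+ # e.g. with two 30-layer towers (the defaults), min_layers // 3 == 10, so layers 0-9 interact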
+ num_interact = min(10, self.min_layers // 3) + interact_layers = list(range(0, num_interact)) + + elif strategy == "distributed": + # Distribute interactions across the network (every few layers). + step = 3 + interact_layers = list(range(0, self.min_layers, step)) + + elif strategy == "progressive": + # Progressive: dense shallow interactions, sparse deeper interactions. + shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers. + if self.min_layers > 8: + deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards. + interact_layers = shallow + deep + else: + interact_layers = shallow + + elif strategy == "custom": + # Custom strategy: adjust as needed. + interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices. + interact_layers = [i for i in interact_layers if i < self.min_layers] + + elif strategy == "full": + interact_layers = list(range(0, self.min_layers)) + + else: + raise ValueError(f"Unknown interaction strategy: {strategy}") + + # Build bidirectional mapping. + mapping = { + 'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i + 'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i + } + + return mapping + + def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool: + """ + Check whether a given layer should interact. + + Args: + layer_idx: current layer index + direction: interaction direction ('v2a' or 'a2v') + interaction_mapping: interaction mapping table + + Returns: + bool: whether to interact + """ + if direction not in interaction_mapping: + return False + + return any(src == layer_idx for src, _ in interaction_mapping[direction]) + + +class ConditionalCrossAttention(nn.Module): + def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.q_dim = dim + self.kv_dim = kv_dim + self.num_heads = num_heads + self.head_dim = self.q_dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(kv_dim, dim) + self.v = nn.Linear(kv_dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + if x_freqs is not None: + x_cos, x_sin = x_freqs + B, L, _ = q.shape + q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim) + x_cos = x_cos.to(q_view.dtype).to(q_view.device) + x_sin = x_sin.to(q_view.dtype).to(q_view.device) + # Expect x_cos/x_sin shape: [B or 1, L, head_dim] + q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2) + q = rearrange(q_view, 'b l h d -> b l (h d)') + if y_freqs is not None: + y_cos, y_sin = y_freqs + Bc, Lc, _ = k.shape + k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim) + y_cos = y_cos.to(k_view.dtype).to(k_view.device) + y_sin = y_sin.to(k_view.dtype).to(k_view.device) + # Expect y_cos/y_sin shape: [B or 1, L, head_dim] + _, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2) + k = rearrange(k_view, 'b l h d -> b l (h d)') + x = self.attn(q, k, v) + return self.o(x) + + +# from diffusers.models.attention import AdaLayerNorm +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. 
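+ The timestep embedding `temb` is passed through SiLU and a Linear projection to produce
+ (scale, shift), and the normalized input is modulated as `norm(x) * (1 + scale) + shift`.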
+ + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 2: + scale, shift = temb.chunk(2, dim=2) + # print(f"{x.shape = }, {scale.shape = }, {shift.shape = }") + elif self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX and OmniGen for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class ConditionalCrossAttentionBlock(nn.Module): + """ + A thin wrapper around ConditionalCrossAttention. + Applies LayerNorm to the conditioning input `y` before cross-attention. + """ + def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False): + super().__init__() + self.y_norm = nn.LayerNorm(kv_dim, eps=eps) + self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps) + self.pooled_adaln = pooled_adaln + if pooled_adaln: + self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps) + self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + ) -> torch.Tensor: + if self.pooled_adaln: + assert video_grid_size is not None, "video_grid_size must not be None" + pooled_y = self.per_frame_pooling(y, video_grid_size) + # Interpolate pooled_y along its temporal dimension to match x's sequence length. + if pooled_y.shape[1] != x.shape[1]: + pooled_y = F.interpolate( + pooled_y.permute(0, 2, 1), # [B, C, T] + size=x.shape[1], + mode='linear', + align_corners=False, + ).permute(0, 2, 1) # [B, T, C] + x = self.adaln(x, temb=pooled_y) + y = self.y_norm(y) + return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs) + + +class DualTowerConditionalBridge(nn.Module): + """ + Dual-tower conditional bridge. 
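+ At each interacting layer, audio hidden states condition the visual DiT through cross-attention
+ ('a2v') and visual hidden states condition the audio DiT ('v2a'); the conditioned features are
+ added residually, scaled by condition_scale. RoPE positions from build_aligned_freqs place
+ video frames and audio steps on a shared time axis.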
+ """ + def __init__(self, + visual_layers: int = 40, + audio_layers: int = 30, + visual_hidden_dim: int = 5120, # visual DiT hidden state dimension + audio_hidden_dim: int = 1536, # audio DiT hidden state dimension + audio_fps: float = 50.0, + head_dim: int = 128, # attention head dimension + interaction_strategy: str = "full", + apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention + apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment + trainable_condition_scale: bool = False, + pooled_adaln: bool = False, + ): + super().__init__() + + self.visual_hidden_dim = visual_hidden_dim + self.audio_hidden_dim = audio_hidden_dim + self.audio_fps = audio_fps + self.head_dim = head_dim + self.apply_cross_rope = apply_cross_rope + self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope + self.trainable_condition_scale = trainable_condition_scale + self.pooled_adaln = pooled_adaln + if self.trainable_condition_scale: + self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32)) + else: + self.condition_scale = 1.0 + + self.controller = CrossModalInteractionController(visual_layers, audio_layers) + self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy) + + # Conditional cross-attention modules operating at the DiT hidden-state level. + self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning + self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning + + # Build conditioners for layers that should interact. + # audio hidden states condition the visual DiT + self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim) + for v_layer, _ in self.interaction_mapping['a2v']: + self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock( + dim=visual_hidden_dim, # 3072 (visual DiT hidden states) + kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states) + num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim + pooled_adaln=False # a2v typically does not need pooled AdaLN + ) + + # visual hidden states condition the audio DiT + for a_layer, _ in self.interaction_mapping['v2a']: + self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock( + dim=audio_hidden_dim, # 1536 (audio DiT hidden states) + kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states) + num_heads=audio_hidden_dim // head_dim, # safe head count derivation + pooled_adaln=self.pooled_adaln + ) + + @torch.no_grad() + def build_aligned_freqs(self, + video_fps: float, + grid_size: Tuple[int, int, int], + audio_steps: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """ + Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w), + and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048). + + Returns: + visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim] + audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim] + """ + f_v, h, w = grid_size + L_v = f_v * h * w + L_a = int(audio_steps) + + device = device or next(self.parameters()).device + dtype = dtype or torch.float32 + + # Audio positions: 0,1,2,...,L_a-1 (audio as reference). 
+ audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0) + + # Video positions: align video frames to audio-step units. + # FIXME(dhyu): hard-coded VAE temporal stride = 4 + if self.apply_first_frame_bias_in_rope: + # Account for the "first frame lasts 1/video_fps" bias. + video_effective_fps = float(video_fps) / 4.0 + if f_v > 0: + t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32) + if f_v > 1: + t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps) + else: + t_starts = torch.zeros((0,), device=device, dtype=torch.float32) + # Convert to audio-step units. + video_pos_per_frame = t_starts * float(self.audio_fps) + else: + # No first-frame bias: uniform alignment. + scale = float(self.audio_fps) / float(video_fps / 4.0) + video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale + # Flatten to f*h*w; tokens within the same frame share the same time position. + video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0) + + # print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}") + # print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}") + + # Build dummy x to produce cos/sin, dim=head_dim. + dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype) + dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype) + + cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos) + cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos) + + return (cos_v, sin_v), (cos_a, sin_a) + + def should_interact(self, layer_idx: int, direction: str) -> bool: + return self.controller.should_interact(layer_idx, direction, self.interaction_mapping) + + def apply_conditional_control( + self, + layer_idx: int, + direction: str, + primary_hidden_states: torch.Tensor, + condition_hidden_states: torch.Tensor, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + condition_scale: Optional[float] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + use_gradient_checkpointing: Optional[bool] = False, + use_gradient_checkpointing_offload: Optional[bool] = False, + ) -> torch.Tensor: + """ + Apply conditional control (at the DiT hidden-state level). 
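+ The conditioned features produced by the selected layer's cross-attention block are added
+ residually to the primary hidden states after scaling by condition_scale.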
+ + Args: + layer_idx: current layer index + direction: conditioning direction + - 'a2v': audio hidden states -> visual DiT + - 'v2a': visual hidden states -> audio DiT + primary_hidden_states: primary DiT hidden states [B, L, hidden_dim] + condition_hidden_states: condition DiT hidden states [B, L, hidden_dim] + condition_scale: conditioning strength (similar to CFG scale) + + Returns: + Conditioned primary DiT hidden states [B, L, hidden_dim] + """ + + if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping): + return primary_hidden_states + + if direction == 'a2v': + # audio hidden states condition the visual DiT + conditioner = self.audio_to_video_conditioners[str(layer_idx)] + + elif direction == 'v2a': + # visual hidden states condition the audio DiT + conditioner = self.video_to_audio_conditioners[str(layer_idx)] + else: + raise ValueError(f"Invalid direction: {direction}") + + conditioned_features = gradient_checkpoint_forward( + conditioner, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x=primary_hidden_states, + y=condition_hidden_states, + x_freqs=x_freqs, + y_freqs=y_freqs, + video_grid_size=video_grid_size, + ) + + if self.trainable_condition_scale and condition_scale is not None: + print( + "[WARN] This model has a trainable condition_scale, but an external " + f"condition_scale={condition_scale} was provided. The trainable condition_scale " + "will be ignored in favor of the external value." + ) + + scale = condition_scale if condition_scale is not None else self.condition_scale + + primary_hidden_states = primary_hidden_states + conditioned_features * scale + + return primary_hidden_states + + def forward( + self, + layer_idx: int, + visual_hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + *, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + a2v_condition_scale: Optional[float] = None, + v2a_condition_scale: Optional[float] = None, + condition_scale: Optional[float] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + use_gradient_checkpointing: Optional[bool] = False, + use_gradient_checkpointing_offload: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply bidirectional conditional control to both visual/audio towers. + + Args: + layer_idx: current layer index + visual_hidden_states: visual DiT hidden states + audio_hidden_states: audio DiT hidden states + x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs. + If provided, x_freqs is assumed to correspond to the primary tower and y_freqs + to the conditioning tower. + a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale) + v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale) + condition_scale: fallback conditioning strength when per-direction scale is None + video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled + + Returns: + (visual_hidden_states, audio_hidden_states), both conditioned in their respective directions. 
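+ Example (illustrative sketch only; `bridge` is a DualTowerConditionalBridge instance,
+ `visual_blocks` / `audio_blocks` stand in for the two towers' transformer blocks, and the
+ `...` arguments abbreviate each tower's usual block inputs):
+
+     x_freqs, y_freqs = bridge.build_aligned_freqs(video_fps, (f, h, w), audio_steps)
+     for i, (v_block, a_block) in enumerate(zip(visual_blocks, audio_blocks)):
+         vis = v_block(vis, ...)
+         aud = a_block(aud, ...)
+         vis, aud = bridge(i, vis, aud, x_freqs=x_freqs, y_freqs=y_freqs,
+                           video_grid_size=(f, h, w))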
+ """ + + visual_conditioned = self.apply_conditional_control( + layer_idx=layer_idx, + direction="a2v", + primary_hidden_states=visual_hidden_states, + condition_hidden_states=audio_hidden_states, + x_freqs=x_freqs, + y_freqs=y_freqs, + condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale, + video_grid_size=video_grid_size, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + audio_conditioned = self.apply_conditional_control( + layer_idx=layer_idx, + direction="v2a", + primary_hidden_states=audio_hidden_states, + condition_hidden_states=visual_hidden_states, + x_freqs=y_freqs, + y_freqs=x_freqs, + condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale, + video_grid_size=video_grid_size, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + return visual_conditioned, audio_conditioned diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..011039842312b01a0bbb69b999bc868902736e9a --- /dev/null +++ b/diffsynth/models/nexus_gen.py @@ -0,0 +1,161 @@ +import torch +from PIL import Image + + +class NexusGenAutoregressiveModel(torch.nn.Module): + def __init__(self, max_length=1024, max_pixels=262640): + super(NexusGenAutoregressiveModel, self).__init__() + from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration + from transformers import Qwen2_5_VLConfig + self.max_length = max_length + self.max_pixels = max_pixels + model_config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLForConditionalGeneration(model_config) + self.processor = None + + + def load_processor(self, path): + from .nexus_gen_ar_model import Qwen2_5_VLProcessor + self.processor = Qwen2_5_VLProcessor.from_pretrained(path) + + + @staticmethod + def state_dict_converter(): + return NexusGenAutoregressiveModelStateDictConverter() + + def bound_image(self, image, max_pixels=262640): + from 
qwen_vl_utils import smart_resize
+        resized_height, resized_width = smart_resize(
+            image.height,
+            image.width,
+            max_pixels=max_pixels,
+        )
+        return image.resize((resized_width, resized_height))
+
+    def get_editing_msg(self, instruction):
+        # '<image>' is the plain-text placeholder that get_target_embeddings later expands
+        # into the <|vision_start|><|image_pad|><|vision_end|> tokens.
+        if '<image>' not in instruction:
+            instruction = '<image> ' + instruction
+        messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: "}]
+        return messages
+
+    def get_generation_msg(self, instruction):
+        instruction = "Generate an image according to the following description: {}".format(instruction)
+        messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}]
+        return messages
+
+    def forward(self, instruction, ref_image=None, num_img_tokens=81):
+        """
+        Generate target embeddings for the given instruction and reference image.
+        """
+        if ref_image is not None:
+            messages = self.get_editing_msg(instruction)
+            images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
+            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
+        else:
+            messages = self.get_generation_msg(instruction)
+            images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
+            output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
+
+        return output_image_embeddings
+
+    def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
+        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+        text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
+        inputs = processor(
+            text=[text],
+            images=images,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(model.device)
+
+        input_embeds = model.model.embed_tokens(inputs['input_ids'])
+        image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
+        # The last num_img_tokens vision features belong to the blank target-image placeholder;
+        # the rest (if any) belong to the reference image.
+        ground_truth_image_embeds = image_embeds[-num_img_tokens:]
+        input_image_embeds = image_embeds[:-num_img_tokens]
+
+        # Split the vision-token positions into reference-image positions and target-image positions.
+        image_mask = inputs['input_ids'] == model.config.image_token_id
+        indices = image_mask.cumsum(dim=1)
+        input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
+        gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
+        input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
+        input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
+
+        # Target-image positions are filled with learned prefill embeddings instead of real vision features.
+        image_prefill_embeds = model.image_prefill_embeds(
+            torch.arange(81, device=model.device).long()
+        )
+        input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
+
+        position_ids, _ = model.get_rope_index(
+            inputs['input_ids'],
+            inputs['image_grid_thw'],
+            attention_mask=inputs['attention_mask'])
+        position_ids = position_ids.contiguous()
+        outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
+        output_image_embeddings = outputs.image_embeddings[:, :-1, :]
+        output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
+        return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
+
+
+class NexusGenAutoregressiveModelStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_civitai(self, state_dict):
+        state_dict = {"model."
+ key: value for key, value in state_dict.items()} + return state_dict diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b647786aafc5c80245272269d3e0a525e03b2da1 --- /dev/null +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -0,0 +1,1143 @@ +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput +from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + QWEN2_5_VL_INPUTS_DOCSTRING, + ) + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + image_embeddings: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + self.image_prefill_embeds = nn.Embedding(81, config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. 
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + 
image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: 
Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + token_loss_weight: Optional[float] = 0.1, + img_loss_weight: Optional[float] = 1.0, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # test feature + inputs_embeds = self.model.embed_tokens(input_ids) + # for image encoding and training + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # position_ids [3, B, L] + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + image_embeds = self.vision_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + # prepare labels for logits + logits_labels = labels.clone().detach() + image_tokens = (labels == self.config.image_token_id) + logits_labels[image_tokens] = -100 + + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = logits_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) * token_loss_weight + + shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1) + shifted_image_embeds = image_embeds[:, :-1, :].contiguous() # (B, L-1, D) + masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d] # (num_image_tokens, D) + + mse_loss_fct = nn.MSELoss() + mse_loss_fct = mse_loss_fct.to(shift_logits.device) + if image_embeddings is None: + image_embeddings = torch.zeros_like(masked_image_embeds) + img_loss = mse_loss_fct(masked_image_embeds, image_embeddings) + + cos_sim = torch.cosine_similarity( + masked_image_embeds, + image_embeddings, + dim=-1 + ) + cos_loss = (1 - cos_sim).mean() + img_loss = 0.5 * img_loss + 0.5 * cos_loss + # fix nan for empty image tokens + if image_embeddings.size(0) == 0: + img_loss = img_loss.nan_to_num(0.0) + # combine the loss + loss = loss + img_loss_weight * img_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + image_embeddings=image_embeds, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
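A hedged usage sketch of the image-sampling path this override adds; the `inputs` dict and prompt handling are placeholders, while `generation_image_grid_thw`, `return_dict_in_generate`, and the `output_image_embeddings` field come from the code below (the default grid `[[1, 18, 18]]` yields 1*18*18/4 = 81 image tokens):

```python
>>> import torch
>>> # `inputs` is assumed to be processor output whose last prompt token is <|vision_start|>,
>>> # which is what triggers image-token sampling in this override.
>>> out = model.generate(
...     **inputs,
...     generation_image_grid_thw=torch.tensor([[1, 18, 18]]),
...     return_dict_in_generate=True,
... )
>>> out.output_image_embeddings   # (batch, 81, hidden_size) if image tokens were sampled, else None
```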
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + model_forward = self.__call__ + if isinstance(model_kwargs.get("past_key_values"), Cache): + is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache + is_compileable = is_compileable and not self.generation_config.disable_compile + if is_compileable and ( + self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices + ): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id + generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", self.get_default_image_grid_thw()) + num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) + output_image_embeddings = [] + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare prefilled embeds + model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs)) + + # parse position_ids from model_kwargs + model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) 
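+            # The first iteration always uses the plain `self(...)` call, so the multimodal
+            # prefill (pixel values, rope-delta computation) runs uncompiled; later decode
+            # steps go through `model_forward`, which may be the compiled call prepared above.
+            # While image tokens are being sampled, the last-position `vision_head` output
+            # (`outputs.image_embeddings[:, -1, :]`) is collected into `output_image_embeddings`.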
+ + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + # TODO: support batch image sampling + if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens: + output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1)) + + if synced_gpus and this_peer_finished: + continue + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # do not sample token + next_token_logits[:, self.config.vision_end_token_id] = -float('inf') + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id): + # probs[:, self.config.vision_end_token_id] = 0 + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + #TODO: support batch image sample + if num_img_tokens is not None: + cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1) + # check whether is sampling images + is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img) + is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens) + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to end sampling images + next_tokens[is_end_img] = self.config.vision_end_token_id + else: + # check whether to end sampling images + is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id)) + # replace the next token with the image token if is sampling image + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to start sampling images + is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id)) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & 
~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + # output the image embeddings + output_image_embeddings = torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None + + if return_dict_in_generate: + return GenerateDecoderOnlyAll2AllOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + output_image_embeddings=output_image_embeddings, + ) + else: + return input_ids + + + def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs): + if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img): + return {} + # TODO: support batch image sample + image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0) + inputs_embeds = self.image_prefill_embeds(image_idx) + return {"inputs_embeds": inputs_embeds} + + + def get_default_image_grid_thw(self,): + return torch.tensor([[1, 18, 18]]).to(self.device) + + + def get_num_image_tokens(self, image_grid_thw): + return int(torch.prod(image_grid_thw, dim=1).sum() // 4) + + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + num_img_tokens = model_kwargs.pop("generation_image_grid_thw", None) + super()._validate_model_kwargs(model_kwargs) + model_kwargs["generation_image_grid_thw"] = num_img_tokens + + def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs): + # Overwritten -- prepare position_ids for image tokens + cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)) + # TODO: support batch image sample + if cur_img_tokens > 0 and bool(is_sampling_img): + image_grid_thw = generation_image_grid_thw + if model_kwargs.get('image_grid_thw') is not None: + image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw]) + remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens + padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id) + padded_ids = torch.cat([input_ids, padding_ids], dim=1) + position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None) + if model_kwargs.get("use_cache", True): + position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1) + else: + position_ids = position_ids[:, :, :input_ids.shape[1]] + return {"position_ids": position_ids} + return {} + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + image_embeddings=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + 
position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = 
torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
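A small instantiation sketch; the checkpoint id below is a placeholder for any repository that ships the matching image-processor and tokenizer files:

```python
>>> processor = Qwen2_5_VLProcessor.from_pretrained("DiffSynth-Studio/Nexus-GenV2")  # placeholder path
>>> processor.image_token   # "<|image_pad|>" unless the tokenizer defines its own image_token
>>> processor.video_token   # "<|video_pad|>" unless the tokenizer defines its own video_token
```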
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
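A hedged call sketch; the prompt string and image size are placeholders, and the returned keys follow the Returns list above:

```python
>>> from PIL import Image
>>> text = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
>>> inputs = processor(text=[text], images=[Image.new("RGB", (252, 252))], return_tensors="pt")
>>> sorted(inputs.keys())   # expect attention_mask, image_grid_thw, input_ids, pixel_values
```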
+ """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def batch_decode_all2all(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + decoded = self.tokenizer.batch_decode(*args, **kwargs) + pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' + decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. 
+ + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..d69b3e1bfd50fc3b9c098f7775afae4020f9b320 --- /dev/null +++ b/diffsynth/models/nexus_gen_projector.py @@ -0,0 +1,417 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple + + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + from transformers.modeling_rope_utils import _compute_default_rope_parameters + self.rope_init_fn = _compute_default_rope_parameters + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + 
inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + 
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NexusGenImageEmbeddingMerger(nn.Module): + def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + super().__init__() + from transformers import Qwen2_5_VLConfig + from transformers.activations import ACT2FN + config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.config = config + self.num_layers = num_layers + self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) + self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, out_channel * expand_ratio), + Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps), + ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel), + 
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps)) + self.base_grid = torch.tensor([[1, 72, 72]], device=device) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device) + + def get_position_ids(self, image_grid_thw): + """ + Generates position ids for the input embeddings grid. + modified from the qwen2_vl mrope. + """ + batch_size = image_grid_thw.shape[0] + spatial_merge_size = self.config.vision_config.spatial_merge_size + t, h, w = ( + image_grid_thw[0][0], + image_grid_thw[0][1], + image_grid_thw[0][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + scale_h = self.base_grid[0][1].item() / h.item() + scale_w = self.base_grid[0][2].item() / w.item() + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + time_tensor = expanded_range * self.config.vision_config.tokens_per_second + t_index = time_tensor.long().flatten().to(image_grid_thw.device) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w + # 3, B, L + position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2) + return position_ids + + def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None): + position_ids = self.get_position_ids(embeds_grid) + hidden_states = embeds + if ref_embeds is not None: + position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid) + position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1) + hidden_states = torch.cat((embeds, ref_embeds), dim=1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + + hidden_states = self.projector(hidden_states) + return hidden_states + + @staticmethod + def state_dict_converter(): + return NexusGenMergerStateDictConverter() + + +class NexusGenMergerStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')} + return merger_state_dict + + +class NexusGenAdapter(nn.Module): + """ + Adapter for Nexus-Gen generation decoder. 
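+    A small MLP (Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm) that projects LLM hidden states
+    (default 3584-dim) into the generation decoder's embedding space (default 4096-dim).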
+ """ + def __init__(self, input_dim=3584, output_dim=4096): + super(NexusGenAdapter, self).__init__() + self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim), + nn.LayerNorm(output_dim), nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.LayerNorm(output_dim)) + + def forward(self, x): + return self.adapter(x) + + @staticmethod + def state_dict_converter(): + return NexusGenAdapterStateDictConverter() + + +class NexusGenAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')} + return adapter_state_dict diff --git a/diffsynth/models/qwen_image_controlnet.py b/diffsynth/models/qwen_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce40809065b3eac020d2b1da29101681a44764a --- /dev/null +++ b/diffsynth/models/qwen_image_controlnet.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from .general_modules import RMSNorm + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072): + super().__init__() + self.x_rms = RMSNorm(dim, eps=1e-6) + self.y_rms = RMSNorm(dim, eps=1e-6) + self.input_proj = nn.Linear(dim, dim) + self.act = nn.GELU() + self.output_proj = nn.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + def init_weights(self): + # zero initialize output_proj + nn.init.zeros_(self.output_proj.weight) + nn.init.zeros_(self.output_proj.bias) + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + ): + super().__init__() + self.img_in = nn.Linear(in_dim + additional_in_dim, dim) + self.controlnet_blocks = nn.ModuleList( + [ + BlockWiseControlBlock(dim) + for _ in range(num_layers) + ] + ) + + def init_weight(self): + nn.init.zeros_(self.img_in.weight) + nn.init.zeros_(self.img_in.bias) + for block in self.controlnet_blocks: + block.init_weights() + + def process_controlnet_conditioning(self, controlnet_conditioning): + return self.img_in(controlnet_conditioning) + + def blockwise_forward(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb6dd251bb7be4553e4b86af38bad89f9b445c6 --- /dev/null +++ b/diffsynth/models/qwen_image_dit.py @@ -0,0 +1,688 @@ +import torch, math, functools +import torch.nn as nn +from typing import Tuple, Optional, Union, List +from einops import rearrange +from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + + +def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False): + if FLASH_ATTN_3_AVAILABLE and attention_mask is None: + if not enable_fp8_attention: + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = 
flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + origin_dtype = q.dtype + q_std, k_std, v_std = q.std(), k.std(), v.std() + q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn) + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1))) + if isinstance(x, tuple): + x = x[0] + x = x.to(origin_dtype) * v_std + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] +): + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + self.rope_cache = {} + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer( + index, + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + + def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): + if isinstance(video_fhw, list): + video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) + _, height, width = video_fhw + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + required_len = max_vid_index + max(txt_seq_lens) + cur_max_len = self.pos_freqs.shape[0] + if required_len <= cur_max_len: + return + + new_max_len = math.ceil(required_len / 512) * 512 + pos_index = torch.arange(new_max_len) + neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, 
self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + return + + + def forward(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
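+        # Text rotary freqs are taken from positions beyond the largest image grid index,
+        # so text tokens and image tokens never share a rotary position.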
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + + def forward_sampling(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache: + frame_0, height_0, width_0 = video_fhw[0] + + rope_key_0 = f"0_{height_0}_{width_0}" + spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1) + h_indices = torch.linspace(0, height_0 - 1, height).long() + w_indices = torch.linspace(0, width_0 - 1, width).long() + h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij') + sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :] + + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame + + seq_lens = frame * height * width + self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone() + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone() + vid_freqs.append(self.rope_cache[rope_key].contiguous()) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
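+        # As in forward(), text freqs start after the largest image index; spatial freqs for images
+        # beyond the first were resampled above from the first image's cached grid so they stay aligned with it.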
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + +class QwenEmbedLayer3DRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + video_fhw = [video_fhw] + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx) + else: + ### For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + @functools.lru_cache(maxsize=None) + def _compute_condition_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dropout: float = 0.0, + ): + super().__init__() + inner_dim = int(dim * 4) + self.net = nn.ModuleList([]) + self.net.append(ApproximateGELU(dim, inner_dim)) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class QwenDoubleStreamAttention(nn.Module): + def __init__( + self, + dim_a, + dim_b, + num_heads, + head_dim, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = nn.Linear(dim_a, dim_a) + self.to_k = nn.Linear(dim_a, dim_a) + self.to_v = nn.Linear(dim_a, dim_a) + self.norm_q = RMSNorm(head_dim, eps=1e-6) + self.norm_k = RMSNorm(head_dim, eps=1e-6) + + self.add_q_proj = nn.Linear(dim_b, dim_b) + self.add_k_proj = nn.Linear(dim_b, dim_b) + self.add_v_proj = nn.Linear(dim_b, dim_b) + self.norm_added_q = RMSNorm(head_dim, eps=1e-6) + self.norm_added_k = RMSNorm(head_dim, eps=1e-6) + + self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a)) + self.to_add_out = nn.Linear(dim_b, 
dim_b) + + def forward( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + enable_fp8_attention: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) + txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) + seq_txt = txt_q.shape[1] + + img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads) + img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads) + img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads) + + txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads) + txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads) + txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads) + + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=2) + joint_k = torch.cat([txt_k, img_k], dim=2) + joint_v = torch.cat([txt_v, img_v], dim=2) + + joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype) + + txt_attn_output = joint_attn_out[:, :seq_txt, :] + img_attn_output = joint_attn_out[:, seq_txt:, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim), + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = QwenDoubleStreamAttention( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + def _modulate(self, x, mod_params, index=None): + shift, scale, gate = mod_params.chunk(3, dim=-1) + if index is not None: + # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + # So shift, scale, gate have shape [2*actual_batch, d] + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + + # index: [b, l] where b is actual batch size + # Expand to [b, l, 1] to match feature dimension + index_expanded = 
index.unsqueeze(-1) # [b, l, 1] + + # Expand chunks to [b, 1, d] then broadcast to [b, l, d] + shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] + shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + + # Use torch.where to select based on index + shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + + return x * (1 + scale_result) + shift_result, gate_result + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + enable_fp8_attention = False, + modulate_index: Optional[List[int]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + if modulate_index is not None: + temb = torch.chunk(temb, 2, dim=0)[0] + txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + + img_normed = self.img_norm1(image) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index) + + txt_normed = self.txt_norm1(text) + txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) + + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + ) + + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index) + + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + + img_mlp_out = self.img_mlp(img_modulated_2) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + + image = image + img_gate_2 * img_mlp_out + text = text + txt_gate_2 * txt_mlp_out + + return text, image + + +class QwenImageDiT(torch.nn.Module): + + _repeated_blocks = ["QwenImageTransformerBlock"] + + def __init__( + self, + num_layers: int = 60, + use_layer3d_rope: bool = False, + use_additional_t_cond: bool = False, + ): + super().__init__() + + if not use_layer3d_rope: + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + else: + self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + + self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond) + self.txt_norm = RMSNorm(3584, eps=1e-6) + + self.img_in = nn.Linear(64, 3072) + self.txt_in = nn.Linear(3584, 3072) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, single=True) + self.proj_out = nn.Linear(3072, 64) + + + def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes): + # prompt_emb + 
all_prompt_emb = entity_prompt_emb + [prompt_emb] + all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] + all_prompt_emb = torch.cat(all_prompt_emb, dim=1) + + # image_rotary_emb + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask] + entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] + txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + + # attention_mask + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype) + entity_masks = entity_masks + [global_mask] + + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + total_seq_len = sum(seq_lens) + image.shape[1] + patched_masks = [] + for i in range(N): + patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + patched_masks.append(patched_mask) + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + # prompt-image attention mask + image_start = sum(seq_lens) + image_end = total_seq_len + cumsum = [0] + single_image_seq = image_end - image_start + for length in seq_lens: + cumsum.append(cumsum[-1] + length) + for i in range(N): + prompt_start = cumsum[i] + prompt_end = cumsum[i+1] + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) + # repeat image mask to match the single image sequence length + repeat_time = single_image_seq // image_mask.shape[-1] + image_mask = image_mask.repeat(1, 1, repeat_time) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt attention mask, let the prompt tokens not attend to each other + for i in range(N): + for j in range(N): + if i == j: + continue + start_i, end_i = cumsum[i], cumsum[i+1] + start_j, end_j = cumsum[j], cumsum[j+1] + attention_mask[:, start_i:end_i, start_j:end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + + return all_prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + ): + img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + + image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image = self.img_in(image) + text = self.txt_in(self.txt_norm(prompt_emb)) + + conditioning = self.time_text_embed(timestep, 
image.dtype) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + + for block in self.transformer_blocks: + text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + ) + + image = self.norm_out(image, conditioning) + image = self.proj_out(image) + + latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) + return image diff --git a/diffsynth/models/qwen_image_image2lora.py b/diffsynth/models/qwen_image_image2lora.py new file mode 100644 index 0000000000000000000000000000000000000000..6aefbf25de6ccdb37de2d2d44e644fb77952b570 --- /dev/null +++ b/diffsynth/models/qwen_image_image2lora.py @@ -0,0 +1,128 @@ +import torch + + +class CompressedMLP(torch.nn.Module): + def __init__(self, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias) + + def forward(self, x, residual=None): + x = self.proj_in(x) + if residual is not None: x = x + residual + x = self.proj_out(x) + return x + + +class ImageEmbeddingToLoraMatrix(torch.nn.Module): + def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank): + super().__init__() + self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank) + self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank) + self.lora_a_dim = lora_a_dim + self.lora_b_dim = lora_b_dim + self.rank = rank + + def forward(self, x, residual=None): + lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim) + lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank) + return lora_a, lora_b + + +class SequencialMLP(torch.nn.Module): + def __init__(self, length, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias) + self.length = length + self.in_dim = in_dim + self.mid_dim = mid_dim + + def forward(self, x): + x = x.view(self.length, self.in_dim) + x = self.proj_in(x) + x = x.view(1, self.length * self.mid_dim) + x = self.proj_out(x) + return x + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class QwenImageImage2LoRAModel(torch.nn.Module): + def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + 
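+        # One LoRATrainerBlock is created per (projection-pattern group, transformer block); each maps the
+        # image embedding to LoRA A/B matrices keyed "transformer_blocks.{block_id}.{name}.lora_A/B.default.weight".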
super().__init__() + self.lora_patterns = [ + [ + ("attn.to_q", 3072, 3072), + ("attn.to_k", 3072, 3072), + ("attn.to_v", 3072, 3072), + ("attn.to_out.0", 3072, 3072), + ], + [ + ("img_mlp.net.2", 3072*4, 3072), + ("img_mod.1", 3072, 3072*6), + ], + [ + ("attn.add_q_proj", 3072, 3072), + ("attn.add_k_proj", 3072, 3072), + ("attn.add_v_proj", 3072, 3072), + ("attn.to_add_out", 3072, 3072), + ], + [ + ("txt_mlp.net.2", 3072*4, 3072), + ("txt_mod.1", 3072, 3072*6), + ], + ] + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) diff --git a/diffsynth/models/qwen_image_text_encoder.py b/diffsynth/models/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f19d2d8ae2a61fd9cc45414a6bac18d28e4edcc9 --- /dev/null +++ b/diffsynth/models/qwen_image_text_encoder.py @@ -0,0 +1,190 @@ +import torch +from typing import Optional, Union + + +class QwenImageTextEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel + config = Qwen2_5_VLConfig(**{ + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "text_config": { + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": None, + "initializer_range": 0.02, + "intermediate_size": 18944, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + 
"full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl_text", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": None, + "torch_dtype": "float32", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": None, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }, + "tie_word_embeddings": False, + "torch_dtype": "float32", + "transformers_version": "4.54.0", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "depth": 32, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 3420, + "model_type": "qwen2_5_vl", + "num_heads": 16, + "out_hidden_size": 3584, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "float32", + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLModel(config) + self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.config = config + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = False + output_hidden_states = True + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states diff --git a/diffsynth/models/qwen_image_vae.py b/diffsynth/models/qwen_image_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..2845354f24ad68214fe3f0ca1264ea9450ab16f3 --- /dev/null +++ b/diffsynth/models/qwen_image_vae.py @@ -0,0 +1,726 @@ +import torch +from typing import List, Optional, Tuple, Union +from torch import nn + + +CACHE_T = 2 + +class QwenImageCausalConv3d(torch.nn.Conv3d): + r""" + A custom 3D causal 
convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = torch.nn.functional.pad(x, padding) + return super().forward(x) + + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = torch.nn.SiLU() + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). 
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
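+        num_layers (int): Number of attention/residual pairs appended after the first residual block.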
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + image_channels=3 + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = torch.nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if 
feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
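+ image_channels (int): Number of channels in the reconstructed output image.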
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + image_channels=3, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageVAE(torch.nn.Module): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + image_channels: int = 3, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels, + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = 
QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels, + ) + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1) + self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1) + + def encode(self, x, **kwargs): + x = x.unsqueeze(2) + x = self.encoder(x) + x = self.quant_conv(x) + x = x[:, :16] + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = (x - mean) * std + x = x.squeeze(2) + return x + + def decode(self, x, **kwargs): + x = x.unsqueeze(2) + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = x / std + mean + x = self.post_quant_conv(x) + x = self.decoder(x) + x = x.squeeze(2) + return x diff --git a/diffsynth/models/sd_text_encoder.py b/diffsynth/models/sd_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a1171c265a43048c1623bd0c2375f4fc3f5e5d --- /dev/null +++ b/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,412 @@ +import torch +from .attention import Attention +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = 
torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = 
torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": 
"encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..509eff40946699dffcb31125916fc7acfad0caa0 --- /dev/null +++ b/diffsynth/models/siglip2_image_encoder.py @@ -0,0 +1,134 @@ +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig +from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast +import torch + +from diffsynth.core.device.npu_compatible_device import get_device_type + + +class Siglip2ImageEncoder(SiglipVisionTransformer): + def __init__(self): + config = SiglipVisionConfig( + 
attention_dropout = 0.0, + dtype = "float32", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1536, + image_size = 384, + intermediate_size = 6144, + layer_norm_eps = 1e-06, + model_type = "siglip_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 40, + patch_size = 16, + transformers_version = "4.56.1", + _attn_implementation = "sdpa" + ) + super().__init__(config) + self.processor = SiglipImageProcessor( + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.5, + 0.5, + 0.5 + ], + image_processor_type = "SiglipImageProcessor", + image_std = [ + 0.5, + 0.5, + 0.5 + ], + processor_class = "SiglipProcessor", + resample = 2, + rescale_factor = 0.00392156862745098, + size = { + "height": 384, + "width": 384 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): + pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] + pixel_values = pixel_values.to(device=device, dtype=torch_dtype) + output_attentions = False + output_hidden_states = False + interpolate_pos_encoding = False + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return pooler_output + + +class Siglip2ImageEncoder428M(Siglip2VisionModel): + def __init__(self): + config = Siglip2VisionConfig( + attention_dropout = 0.0, + dtype = "bfloat16", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1152, + intermediate_size = 4304, + layer_norm_eps = 1e-06, + model_type = "siglip2_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 27, + num_patches = 256, + patch_size = 16, + transformers_version = "4.57.1" + ) + super().__init__(config) + self.processor = Siglip2ImageProcessorFast( + **{ + "data_format": "channels_first", + "default_to_square": True, + "device": None, + "disable_grouping": None, + "do_convert_rgb": None, + "do_normalize": True, + "do_pad": None, + "do_rescale": True, + "do_resize": True, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_processor_type": "Siglip2ImageProcessorFast", + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "input_data_format": None, + "max_num_patches": 256, + "pad_size": None, + "patch_size": 16, + "processor_class": "Siglip2Processor", + "resample": 2, + "rescale_factor": 0.00392156862745098, + "return_tensors": None, + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = super().forward(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + hidden_state = hidden_state.to(torch_dtype) + return hidden_state diff --git a/diffsynth/models/step1x_connector.py b/diffsynth/models/step1x_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..225c8fbcb54f8daebf48656141e9a8998d002fd8 --- /dev/null +++ b/diffsynth/models/step1x_connector.py @@ -0,0 +1,663 @@ +from typing import Optional + +import torch, math 
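+# Token-refiner connector modules that refine LLM text embeddings into conditioning features (see SingleTokenRefiner and Qwen2Connector below).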
+import torch.nn +from einops import rearrange +from torch import nn +from functools import partial +from einops import rearrange + + + +def attention(q, k, v, attn_mask, mode="torch"): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "b n s d -> b s (n d)") + return x + + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, device=device, dtype=dtype) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True, **factory_kwargs + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore + nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. 
+ + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(t.dtype) # type: ignore + t_emb = self.mlp(t_freq) + return t_emb + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
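+
+ Raises:
+ NotImplementedError: If norm_layer is not "layer" or "rms".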
+ """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + +class IndividualTokenRefinerBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + + if self.need_CA: + self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs,) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + if self.need_CA: + x = self.cross_attnblock(x, c, attn_mask, y) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + + + +class 
CrossAttnBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.norm1_2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_q = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + self.self_attn_kv = nn.Linear( + hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor=None, + + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + norm_y = self.norm1_2(y) + q = self.self_attn_q(norm_x) + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + kv = self.self_attn_kv(norm_y) + k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + return x + + + +class IndividualTokenRefiner(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=self.need_CA, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + y:torch.Tensor=None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] 
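+ # Build a pairwise (batch, 1, seq_len, seq_len) mask below: a token may attend to another only when both positions are unmasked.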
+ seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + + for block in self.blocks: + x = block(x, c, self_attn_mask,y) + + return x + + +class SingleTokenRefiner(torch.nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.need_CA = need_CA + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + if self.need_CA: + self.input_embedder_CA = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection( + in_channels, hidden_size, act_layer, **factory_kwargs + ) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=need_CA, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + y: torch.LongTensor=None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + if self.need_CA: + y = self.input_embedder_CA(y) + x = self.individual_token_refiner(x, c, mask, y) + else: + x = self.individual_token_refiner(x, c, mask) + + return x + + +class Qwen2Connector(torch.nn.Module): + def __init__( + self, + # biclip_dim=1024, + in_channels=3584, + hidden_size=4096, + heads_num=32, + depth=2, + need_CA=False, + device=None, + dtype=torch.bfloat16, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype":dtype} + + self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) + self.global_proj_out=nn.Linear(in_channels,768) + + self.scale_factor = nn.Parameter(torch.zeros(1)) + with torch.no_grad(): + self.scale_factor.data += -(1 - 0.09) + + def forward(self, x,t,mask): + mask_float = 
mask.unsqueeze(-1) # [b, s1, 1] + x_mean = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device)) + + global_out=self.global_proj_out(x_mean) + encoder_hidden_states = self.S(x,t,mask) + return encoder_hidden_states,global_out diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d144236a33c923b8650057b420681ea48006efd --- /dev/null +++ b/diffsynth/models/step1x_text_encoder.py @@ -0,0 +1,195 @@ +import torch +from typing import Optional, Union +from .qwen_image_text_encoder import QwenImageTextEncoder +from ..core.device.npu_compatible_device import get_device_type, get_torch_device + + +class Step1xEditEmbedder(torch.nn.Module): + def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()): + super().__init__() + self.max_length = max_length + self.dtype = dtype + self.device = device + + Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + self.prefix = Qwen25VL_7b_PREFIX + self.model = model + self.processor = processor + + def model_forward( + self, + model: QwenImageTextEncoder, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states + ) + + outputs = model.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + 
second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states + + def forward(self, caption, ref_images): + text_list = caption + embs = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=get_torch_device().current_device(), + ) + masks = torch.zeros( + len(text_list), + self.max_length, + dtype=torch.long, + device=get_torch_device().current_device(), + ) + + def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # whitespace also gets its own quoted token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): + + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": imgs}) + + # then append the text prompt + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs = [imgs] + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type()) + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + .unsqueeze(0) + .to(get_device_type()) + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type()) + outputs = self.model_forward( + self.model, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to(get_device_type()), + image_grid_thw=inputs.image_grid_thw.to(get_device_type()), + output_hidden_states=True, + ) + + emb = outputs[-1] + + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=get_torch_device().current_device(), + ) + + return embs, masks diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py new file mode 100644 index
0000000000000000000000000000000000000000..3ace70d8b3162e77844f0957cd40207a54e674a9 --- /dev/null +++ b/diffsynth/models/wan_video_animate_adapter.py @@ -0,0 +1,650 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import Tuple, Optional, List +from einops import rearrange + + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. 
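+ elementwise_affine (bool, optional): If True, a learnable per-channel scale (weight) is applied after normalization. Default is True.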
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + 
self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) H N D") + v = rearrange(v, "B L N H D -> (B L) H N D") + + q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp) + # Compute attention. + attn = F.scaled_dot_product_attention(q, k, v) + + attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.kernel = torch.nn.Parameter(kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, 
negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def 
forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + motion = self.dec.direction(motion_feat) + return motion + + +class WanAnimateAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5) + self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4) + + def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + adapter_args = [x, motion_vec, motion_masks, False] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x diff --git a/diffsynth/models/wan_video_camera_controller.py b/diffsynth/models/wan_video_camera_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..45a44ee6bcd408d7ee9d18653f933151ce351a72 --- /dev/null +++ b/diffsynth/models/wan_video_camera_controller.py @@ -0,0 
+1,206 @@ +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +import os +from typing_extensions import Literal + +class SimpleAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1): + super(SimpleAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + def process_camera_coordinates( + self, + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"], + length: int, + height: int, + width: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + ): + if origin is None: + origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + coordinates = generate_camera_coordinates(direction, length, speed, origin) + plucker_embedding = process_pose_file(coordinates, width, height) + return plucker_embedding + + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def custom_meshgrid(*args): + # torch>=2.0.0 only + return torch.meshgrid(*args, indexing='ij') + + +def ray_condition(K, c2w, H, W, device): + """Copied from 
https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.linalg.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + + +def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + + + +def generate_camera_coordinates( + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"], + length: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) +): + coordinates = [list(origin)] + while len(coordinates) < length: + coor = coordinates[-1].copy() + if "Left" in direction: + coor[9] += speed + if "Right" in direction: + coor[9] -= speed + if "Up" in direction: + coor[13] += speed + if "Down" in direction: + coor[13] -= speed + if "In" in direction: + coor[18] -= speed + if "Out" in direction: + coor[18] += speed + coordinates.append(coor) + return coordinates diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..52f607e10a66a13180438c92be2eb32e24a5fc85 --- /dev/null +++ b/diffsynth/models/wan_video_dit.py @@ -0,0 +1,551 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math 
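+ # Attention backends are probed at import time below: flash-attn 3 is preferred, then
+ # flash-attn 2, then SageAttention, falling back to torch scaled_dot_product_attention
+ # (or forced to it with compatibility_mode=True). All backends take packed q/k/v of
+ # shape (b, s, n*d) and return the same shape, e.g.:
+ #   q = k = v = torch.randn(1, 1024, 5120)         # 40 heads x 128 dims per head
+ #   out = flash_attention(q, k, v, num_heads=40)   # -> (1, 1024, 5120)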
+from typing import Tuple, Optional +from einops import rearrange +from .wan_video_camera_controller import SimpleAdapter +from ..core.gradient import gradient_checkpoint_forward +from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x,tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], 
x.shape[1], x.shape[2], -1, 2)) + freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +def set_to_torch_norm(models): + for model in models: + for module in model.modules(): + if isinstance(module, RMSNorm): + module.use_torch_norm = True + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.use_torch_norm = False + self.normalized_shape = (dim,) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + if self.use_torch_norm: + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + else: + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) + return self.o(x) + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = 
nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +def wantodance_torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = wantodance_torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +class WanToDanceInjector(nn.Module): + def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]): + super().__init__() + self.injected_block_id = {} + injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if f'root.transformer_blocks.{inject_id}' == mod_name: + self.injected_block_id[inject_id] = injector_id + injector_id += 1 + + self.injector = nn.ModuleList( + [ + CrossAttention( + dim=dim, + num_heads=num_heads, + ) + for _ in range(injector_id) + ] + ) + self.injector_pre_norm_feat 
= nn.ModuleList( + [ + nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,) + for _ in range(injector_id) + ] + ) + self.injector_pre_norm_vec = nn.ModuleList( + [ + nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,) + for _ in range(injector_id) + ] + ) + + +class WanModel(torch.nn.Module): + + _repeated_blocks = ["DiTBlock"] + + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + wantodance_enable_music_inject: bool = False, + wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27], + wantodance_enable_refimage: bool = False, + wantodance_enable_refface: bool = False, + wantodance_enable_global: bool = False, + wantodance_enable_dynamicfps: bool = False, + wantodance_enable_unimodel: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + + if wantodance_enable_dynamicfps or wantodance_enable_unimodel: + end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350 + self.freqs = precompute_freqs_cis_3d(head_dim, end=end) + else: + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps, + wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface, + wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel) + + def prepare_wantodance( + self, + in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps, + wantodance_enable_music_inject: bool = False, + wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27], + wantodance_enable_refimage: bool = False, + wantodance_enable_refface: bool = False, + wantodance_enable_global: bool = False, 
+ wantodance_enable_dynamicfps: bool = False, + wantodance_enable_unimodel: bool = False, + ): + if wantodance_enable_music_inject: + all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks") + self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers) + if wantodance_enable_refimage: + self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if wantodance_enable_refface: + self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel: + music_feature_dim = 35 + ff_size = 1024 + dropout = 0.1 + latent_dim = 256 + nhead = 4 + activation = F.gelu + rotary = WanToDanceRotaryEmbedding(dim=latent_dim) + self.music_projection = nn.Linear(music_feature_dim, latent_dim) + self.music_encoder = nn.Sequential() + for _ in range(2): + self.music_encoder.append( + WanToDanceMusicEncoderLayer( + d_model=latent_dim, + nhead=nhead, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + batch_first=True, + rotary=rotary, + device='cuda', + ) + ) + if wantodance_enable_unimodel: + self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + if wantodance_enable_unimodel: + self.head_global = Head(dim, out_dim, patch_size, eps) + self.wantodance_enable_music_inject = wantodance_enable_music_inject + self.wantodance_enable_refimage = wantodance_enable_refimage + self.wantodance_enable_refface = wantodance_enable_refface + self.wantodance_enable_global = wantodance_enable_global + self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps + self.wantodance_enable_unimodel = wantodance_enable_unimodel + + def wantodance_after_transformer_block(self, block_idx, hidden_states): + if self.wantodance_enable_music_inject: + if block_idx in self.music_injector.injected_block_id.keys(): + audio_attn_id = self.music_injector.injected_block_id[block_idx] + audio_emb = self.merged_audio_emb # b f n c + num_frames = audio_emb.shape[1] + input_hidden_states = hidden_states.clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states = hidden_states + residual_out + return hidden_states + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False): + if enable_wantodance_global: + x = self.patch_embedding_global(x) + else: + x = self.patch_embedding(x) + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + x = [u + v for u, v in zip(x, y_camera)] + x = x[0].unsqueeze(0) + return x + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + 
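+ # x: video latents, (b, in_dim, f, h, w); e.g. an 81-frame 480x832 clip gives
+ #    (1, 16, 21, 60, 104) under the Wan VAE's 4x temporal / 8x spatial compression.
+ # timestep: diffusion timestep(s), typically shape (b,).
+ # context: text-encoder token embeddings, (b, seq_len, text_dim).
+ # clip_feature: optional CLIP image tokens for image-to-video (1280-dim, typically 257).
+ # y: optional conditioning latents, concatenated with x along the channel axis.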
x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + for block in self.blocks: + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d1abe4c293e0615532f01ac5343dd9582f9636 --- /dev/null +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -0,0 +1,568 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple +from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d +from ..core.gradient import gradient_checkpoint_forward + + +def torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 
1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1 + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) 
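+ # Global branch: mirrors the local path above but without the per-head channel
+ # expansion; the two stride-2 causal convs downsample the motion sequence roughly
+ # 4x in time before the final linear projection.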
+ x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class FramePackMotioner(nn.Module): + + def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2 + ) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), + 
torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None + ) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +class AdaLayerNorm(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_dim: int, + norm_eps: float = 1e-5, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False) + + def forward(self, x, temb): + temb = self.linear(F.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + enable_adain=False, + adain_dim=2048, + ): + super().__init__() + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([CrossAttention( + dim=dim, + num_heads=num_heads, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)]) + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights.to(device=features.device, dtype=features.dtype)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class WanS2VDiTBlock(DiTBlock): + + def forward(self, x, context, t_mod, seq_len_x, freqs): + t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
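+ # broadcast the two modulation sets over the token sequence: the first seq_len_x tokens
+ # (video latents) use set 0, the remaining tokens (reference image / motion frames) use set 1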
+ t_mod = [ + torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + for element in t_mod + ] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class WanS2VModel(torch.nn.Module): + + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + cond_dim: int, + audio_dim: int, + num_audio_token: int, + enable_adain: bool = True, + audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + zero_timestep: bool = True, + add_last_motion: bool = True, + framepack_drop_mode: str = "padd", + fuse_vae_embedding_in_latents: bool = True, + require_vae_embedding: bool = False, + seperated_timestep: bool = False, + require_clip_embedding: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.patch_size = patch_size + self.num_heads = num_heads + self.enbale_adain = enable_adain + self.add_last_motion = add_last_motion + self.zero_timestep = zero_timestep + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + self.require_vae_embedding = require_vae_embedding + self.seperated_timestep = seperated_timestep + self.require_clip_embedding = require_clip_embedding + + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = Head(dim, out_dim, patch_size, eps) + self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1) + + self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) + self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=dim, + num_heads=num_heads, + inject_layer=audio_inject_layers, + enable_adain=enable_adain, + adain_dim=dim, + ) + self.trainable_cond_mask = nn.Embedding(3, dim) + self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode) + + def patchify(self, x: torch.Tensor): + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, + 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], + h=grid_size[1], + w=grid_size[2], + x=self.patch_size[0], + y=self.patch_size[1], + z=self.patch_size[2] + ) + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, 
add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): + # inject the motion frames token to the hidden states + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) + if len(mot) > 0: + x = torch.cat([x, mot[0]], dim=1) + rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1) + mask_input = torch.cat( + [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1 + ) + return x, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + num_frames = audio_emb.shape[1] + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sp_group + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + return hidden_states + + def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input) + audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() + merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + return audio_emb_global, merged_audio_emb + + def get_grid_sizes(self, grid_size_x, grid_size_ref): + f, h, w = grid_size_x + rf, rh, rw = grid_size_ref + grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) + grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]] + grid_sizes_ref = [[ + torch.tensor([30, 0, 0]).unsqueeze(0), + torch.tensor([31, rh, rw]).unsqueeze(0), + torch.tensor([1, rh, rw]).unsqueeze(0), + ]] + return grid_sizes_x + grid_sizes_ref + + def forward( + self, + latents, + timestep, + context, + audio_input, + motion_latents, + pose_cond, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False + ): + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = 
self.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) + + # x and pose_cond + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = x.shape[1] + + # reference image + ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute( + x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + ) + # motion + x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + + x = x + self.trainable_cond_mask(mask).to(x.dtype) + + # t_mod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) + + for block_id, block in enumerate(self.blocks): + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] + ) + x = gradient_checkpoint_forward( + lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) + + x = x[:, :seq_len_x] + x = self.head(x, t[:-1]) + x = self.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37d17d6a183f5d6290c3f4b1e417bf0310bfa353 --- /dev/null +++ b/diffsynth/models/wan_video_image_encoder.py @@ -0,0 +1,878 @@ +""" +Concise re-implementation of +``https://github.com/openai/CLIP'' and +``https://github.com/mlfoundations/open_clip''. +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from .wan_video_dit import flash_attention + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. 
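+ mask: [B, 1, 1, L] additive attention mask (0 for visible positions, a large negative
+ value for padded positions), forwarded to scaled_dot_product_attention.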
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. 
+ """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init model + if pretrained: + from sora import DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = XLMRoberta(**cfg) + + # load checkpoint + model.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'), + map_location=device), + assign=True) + else: + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + + # init tokenizer + if return_tokenizer: + from sora.data import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.text_len, + clean='whitespace') + return model, tokenizer + else: + return model + + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + # compute query, key, value + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1) + k, v = self.to_kv(x).chunk(2, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + e = e.to(dtype=x.dtype, device=x.device) + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_mlp_ratio=4, + vision_heads=12, + vision_layers=12, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + vocab_size=49408, + text_len=77, + text_dim=512, + text_mlp_ratio=4, + text_heads=8, + text_layers=12, + text_causal=True, + text_pool='argmax', + text_head_bias=False, + logit_bias=None, + 
activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pool = vision_pool + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_mlp_ratio = text_mlp_ratio + self.text_heads = text_heads + self.text_layers = text_layers + self.text_causal = text_causal + self.text_pool = text_pool + self.text_head_bias = text_head_bias + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + mlp_ratio=text_mlp_ratio, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + causal=text_causal, + pool_type=text_pool, + head_bias=text_head_bias, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + if logit_bias is not None: + self.logit_bias = nn.Parameter(logit_bias * torch.ones([])) + + # initialize weights + self.init_weights() + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer. 
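+ Returns:
+ xi: visual features from self.visual(imgs).
+ xt: textual features from self.textual(txt_ids).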
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, std=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else self.text_dim + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * len(transformer))) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = None + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. 
+ - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=CLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init model + if pretrained and pretrained_name: + from sora import BUCKET, DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = model_cls(**kwargs) + + # checkpoint path + checkpoint = f'models/clip/{pretrained_name}' + if dtype in (torch.float16, torch.bfloat16): + suffix = '-' + { + torch.float16: 'fp16', + torch.bfloat16: 'bf16' + }[dtype] + if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'): + checkpoint = f'{checkpoint}{suffix}' + checkpoint += '.pth' + + # load + model.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device), + assign=True, + strict=False) + else: + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + + # init tokenizer + if return_tokenizer: + from sora import data + if 'siglip' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name=f'timm/{pretrained_name}', + seq_len=model.text_len, + clean='canonicalize') + elif 'xlm' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.max_text_len - 2, + clean='whitespace') + elif 'mba' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='facebook/xlm-roberta-xl', + seq_len=model.max_text_len - 2, + clean='whitespace') + else: + tokenizer = data.CLIPTokenizer( + seq_len=model.text_len, padding=tokenizer_padding) + output += (tokenizer,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class WanImageEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + 
return_transforms=True, + return_tokenizer=False, + dtype=torch.float32, + device="cpu") + + def encode_image(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u, + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + out = self.model.visual(videos, use_31_block=True) + return out diff --git a/diffsynth/models/wan_video_mot.py b/diffsynth/models/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..4091c91777355dce91ccefac56679f8b936e7abb --- /dev/null +++ b/diffsynth/models/wan_video_mot.py @@ -0,0 +1,169 @@ +import torch +from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP +import einops +import torch.nn as nn + + +class MotSelfAttention(SelfAttention): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__(dim, num_heads, eps) + def forward(self, x, freqs, is_before_attn=False): + if is_before_attn: + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + return q, k, v + else: + return self.o(x) + + +class MotWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + + self.self_attn = MotSelfAttention(dim, num_heads, eps) + + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot): + + # 1. prepare scale parameter + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + + scale_params_mot_ref = self.modulation + t_mod_mot.float() + scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1) + shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2) + + # 2. 
Self-attention + input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa) + # original block self-attn + attn1 = wan_block.self_attn + q = attn1.norm_q(attn1.q(input_x)) + k = attn1.norm_k(attn1.k(input_x)) + v = attn1.v(input_x) + q = rope_apply(q, freqs, attn1.num_heads) + k = rope_apply(k, freqs, attn1.num_heads) + + # mot block self-attn + norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot) + norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1) + q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True) + + tmp_hidden_states = flash_attention( + torch.cat([q, q_mot], dim=-2), + torch.cat([k, k_mot], dim=-2), + torch.cat([v, v_mot], dim=-2), + num_heads=attn1.num_heads) + + attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2) + + attn_output = attn1.o(attn_output) + x = wan_block.gate(x, gate_msa, attn_output) + + attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False) + # gate + attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1) + attn_output_mot = attn_output_mot * gate_msa_mot_ref + attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot) + + # 3. cross-attention and feed-forward + x = x + wan_block.cross_attn(wan_block.norm3(x), context) + input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp) + x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x)) + + x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot) + # modulate + norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot) + norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1) + input_x_mot = self.ffn(norm_x_mot_ref) + # gate + input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1) + input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref + input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + input_x_mot).type_as(x_mot) + + return x, x_mot + + +class MotWanModel(torch.nn.Module): + def __init__( + self, + mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + patch_size=(1, 2, 2), + has_image_input=True, + has_image_pos_emb=False, + dim=5120, + num_heads=40, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + in_dim=36, + eps=1e-6, + ): + super().__init__() + self.mot_layers = mot_layers + self.freq_dim = freq_dim + self.dim = dim + + self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)} + self.head_dim = dim // num_heads + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) + + # mot blocks + self.blocks = torch.nn.ModuleList([ + MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.mot_layers + ]) + + + def patchify(self, x: torch.Tensor): + x = 
self.patch_embedding(x) + return x + + def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0): + def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta) + h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + + freqs = torch.cat([ + f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1), + h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1), + w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1) + return freqs + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id): + block = self.blocks[self.mot_layers_mapping[block_id]] + x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot) + return x, x_mot diff --git a/diffsynth/models/wan_video_motion_controller.py b/diffsynth/models/wan_video_motion_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..34763a8d76e57bc8efff84f23863938cc2309029 --- /dev/null +++ b/diffsynth/models/wan_video_motion_controller.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +from .wan_video_dit import sinusoidal_embedding_1d + + + +class WanMotionControllerModel(torch.nn.Module): + def __init__(self, freq_dim=256, dim=1536): + super().__init__() + self.freq_dim = freq_dim + self.linear = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6), + ) + + def forward(self, motion_bucket_id): + emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10) + emb = self.linear(emb) + return emb + + def init(self): + state_dict = self.linear[-1].state_dict() + state_dict = {i: state_dict[i] * 0 for i in state_dict} + self.linear[-1].load_state_dict(state_dict) diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..64090db8c65138abfdb60a822b3ba2e74fefeb4c --- /dev/null +++ b/diffsynth/models/wan_video_text_encoder.py @@ -0,0 +1,330 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoTokenizer +import ftfy +import html +import string +import regex as re + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, 
num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + 
torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class WanTextEncoder(torch.nn.Module): + + def __init__(self, + vocab=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1): + super(WanTextEncoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', 
text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text \ No newline at end of file diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0f72998d77962f77775629ac5e83611a1acc2b --- /dev/null +++ b/diffsynth/models/wan_video_vace.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +from .wan_video_dit import DiTBlock, CrossAttention, flash_attention +from ..core.gradient import gradient_checkpoint_forward + + +class _OffloadToCPU(torch.autograd.Function): + """Move tensor to CPU in forward, move gradient to GPU in backward.""" + @staticmethod + def forward(ctx, tensor): + ctx.gpu_device = tensor.device + return tensor.detach().cpu().requires_grad_(True) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to(ctx.gpu_device) + + +class _RestoreToGPU(torch.autograd.Function): + """Move CPU tensor back to GPU in forward, move gradient to CPU in backward.""" + @staticmethod + def forward(ctx, tensor, device): + ctx.save_for_backward(torch.empty(0)) + ctx._device_str = str(device) + return tensor.to(device) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.cpu(), None + + +def tokenize_target_text(text, max_len=64, vocab_size=8192): + """Convert target text string to character-level token IDs. + + Uses modular hashing of Unicode code points to map any character + (ASCII, CJK, Cyrillic, math symbols, etc.) into a fixed vocabulary. + Token 0 is reserved for padding. + """ + ids = [] + for ch in text[:max_len]: + ids.append(ord(ch) % (vocab_size - 1) + 1) + ids += [0] * (max_len - len(ids)) + return ids + + +class ConditionCrossAttention(nn.Module): + """Cross-attention for injecting condition features into VACE blocks. + + Used for both glyph visual features (v2) and character-level text + tokens (v3). Each spatial position in the VACE hidden states queries + the condition tokens to determine what should be rendered at that location. + + Zero-initialized output projection ensures the model starts from + pretrained VACE behavior and gradually learns to use the condition. 
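+
+ Condition tokens have shape (B, num_tokens, dim) and come either from GlyphEncoder
+ (VAE-encoded glyph video, v2) or TargetTextEncoder (character-level token IDs, v3).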
+ """ + + def __init__(self, dim, num_heads, eps=1e-6): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(dim, eps=eps) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + + # Zero-init output projection so condition attention starts as no-op + nn.init.zeros_(self.o.weight) + nn.init.zeros_(self.o.bias) + + def forward(self, x, condition_tokens): + """ + Args: + x: VACE hidden states (B, seq_len, dim) + condition_tokens: condition features (B, num_tokens, dim) + Returns: + x + condition_attn_output (B, seq_len, dim) + """ + residual = x + x = self.norm(x) + q = self.q(x) + k = self.k(condition_tokens) + v = self.v(condition_tokens) + out = flash_attention(q, k, v, self.num_heads) + return residual + self.o(out) + + +# Backward compatibility alias +GlyphCrossAttention = ConditionCrossAttention + + +class VaceWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = torch.nn.Linear(self.dim, self.dim) + self.after_proj = torch.nn.Linear(self.dim, self.dim) + + def forward(self, c, x, context, t_mod, freqs, condition_tokens=None): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, context, t_mod, freqs) + + # Condition cross-attention (injected after text cross-attn + FFN) + # Works for both glyph tokens (v2) and character tokens (v3) + if condition_tokens is not None and hasattr(self, 'condition_cross_attn'): + c = self.condition_cross_attn(c, condition_tokens) + + c_skip = self.after_proj(c) + return c_skip, c + + +class GlyphEncoder(nn.Module): + """Lightweight encoder that compresses glyph latents into tokens + for cross-attention in VACE blocks. + + glyph_latent (16ch, from VAE) → Conv3D patch embed → spatial pooling + → compact token sequence that each VACE block attends to. + """ + + def __init__(self, in_channels=16, dim=1536, num_tokens=64, patch_size=(1, 2, 2)): + super().__init__() + self.num_tokens = num_tokens + + # Patch embedding (same architecture as vace_patch_embedding) + self.patch_embed = nn.Conv3d(in_channels, dim, kernel_size=patch_size, stride=patch_size) + + # Compress spatial sequence to fixed number of tokens via cross-attention pooling + self.query_tokens = nn.Parameter(torch.randn(1, num_tokens, dim) * 0.02) + self.pool_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True) + self.pool_norm = nn.LayerNorm(dim) + + # Output projection (zero-init for safe initialization) + self.out_proj = nn.Linear(dim, dim) + nn.init.zeros_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward(self, glyph_latent): + """ + Args: + glyph_latent: (B, 16, T, H, W) from VAE-encoded glyph video + Returns: + glyph_tokens: (B, num_tokens, dim) compressed glyph features + """ + # Patch embed: (B, dim, T, H/2, W/2) + x = self.patch_embed(glyph_latent) + B = x.shape[0] + # Flatten spatial: (B, dim, N) → (B, N, dim) + x = x.flatten(2).transpose(1, 2) + + # Cross-attention pooling: compress N spatial tokens → num_tokens + queries = self.query_tokens.expand(B, -1, -1).to(dtype=x.dtype, device=x.device) + x_normed = self.pool_norm(x) + pooled, _ = self.pool_attn(queries, x_normed, x_normed) + + return self.out_proj(pooled) + + +class TargetTextEncoder(nn.Module): + """Character-level encoder for target text strings. 
+ + Encodes the target text (e.g., "NAD", "CARLIFE", "善") at the character + level using learned embeddings and a small Transformer. This provides + character-precise identity information that T5's subword tokenization + cannot guarantee. + + The output tokens are used as K/V in ConditionCrossAttention within + each VACE block, allowing each spatial position to query which character + it should render. + """ + + def __init__(self, vocab_size=8192, max_len=64, dim=1536, + num_layers=2, num_heads=8): + super().__init__() + self.vocab_size = vocab_size + self.max_len = max_len + + self.char_embed = nn.Embedding(vocab_size, dim, padding_idx=0) + self.pos_embed = nn.Embedding(max_len, dim) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=dim, nhead=num_heads, dim_feedforward=dim * 4, + batch_first=True, norm_first=True, activation='gelu', + ) + self.transformer = nn.TransformerEncoder( + encoder_layer, num_layers=num_layers, + ) + + # Zero-init output projection for safe initialization + self.out_proj = nn.Linear(dim, dim) + nn.init.zeros_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward(self, token_ids): + """ + Args: + token_ids: (B, max_len) long tensor of character IDs (0 = PAD) + Returns: + text_tokens: (B, max_len, dim) + """ + B, L = token_ids.shape + positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1) + + x = self.char_embed(token_ids) + self.pos_embed(positions) + + # Padding mask: True means ignore this position + pad_mask = (token_ids == 0) + + x = self.transformer(x, src_key_padding_mask=pad_mask) + return self.out_proj(x) + + +class VaceWanModel(torch.nn.Module): + def __init__( + self, + vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + vace_in_dim=96, + glyph_channels=0, + glyph_num_tokens=64, + use_target_text_encoder=False, + text_vocab_size=8192, + text_max_len=64, + text_num_layers=2, + patch_size=(1, 2, 2), + has_image_input=False, + dim=1536, + num_heads=12, + ffn_dim=8960, + eps=1e-6, + ): + super().__init__() + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + self.glyph_channels = glyph_channels + self.use_target_text_encoder = use_target_text_encoder + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # VACE blocks + self.vace_blocks = torch.nn.ModuleList([ + VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.vace_layers + ]) + + # VACE patch embedding (original 96 channels, no glyph concat) + self.vace_patch_embedding = torch.nn.Conv3d( + vace_in_dim, dim, kernel_size=patch_size, stride=patch_size + ) + + # Glyph pathway (v2: glyph video → GlyphEncoder → cross-attention) + if glyph_channels > 0: + self.glyph_encoder = GlyphEncoder( + in_channels=glyph_channels, + dim=dim, + num_tokens=glyph_num_tokens, + patch_size=patch_size, + ) + for block in self.vace_blocks: + block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps) + + # Target text pathway (v3: text string → TargetTextEncoder → cross-attention) + if use_target_text_encoder: + self.target_text_encoder = TargetTextEncoder( + vocab_size=text_vocab_size, + max_len=text_max_len, + dim=dim, + num_layers=text_num_layers, + num_heads=num_heads, + ) + for block in self.vace_blocks: + block.condition_cross_attn = ConditionCrossAttention(dim, num_heads, eps) + + def _has_new_modules(self): + """Check if this model has modules not present in pretrained checkpoints.""" + return self.glyph_channels > 0 or self.use_target_text_encoder + + 
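+    # Minimal usage sketch (illustrative only; variable names below are
+    # placeholders, not part of this module). With a model built with
+    # use_target_text_encoder=True, the v3 pathway is driven by passing
+    # character token IDs to forward():
+    #
+    #     ids = torch.tensor([tokenize_target_text("CARLIFE")])  # (1, 64) char IDs
+    #     hints = vace(x, vace_context, context, t_mod, freqs,
+    #                  target_text_ids=ids)
+    #
+    # For the v2 pathway, pass a VAE-encoded glyph video via glyph_latent
+    # instead; in both cases the resulting tokens feed each block's
+    # ConditionCrossAttention.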
def load_state_dict(self, state_dict, strict=True, assign=False): + if self._has_new_modules(): + # New modules won't be in pretrained checkpoints. + # First, materialize any meta-tensor parameters so they can be assigned. + for name, param in self.named_parameters(): + if param.is_meta: + materialized = torch.zeros(param.shape, dtype=param.dtype, device='cpu') + parts = name.split('.') + module = self + for p in parts[:-1]: + module = getattr(module, p) + setattr(module, parts[-1], torch.nn.Parameter(materialized)) + + result = super().load_state_dict(state_dict, strict=False, assign=assign) + + # Re-initialize glyph modules (v2 mode) + if hasattr(self, 'glyph_encoder'): + for name, module in self.glyph_encoder.named_modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv3d)): + if 'out_proj' not in name: + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + if hasattr(self.glyph_encoder, 'query_tokens'): + torch.nn.init.normal_(self.glyph_encoder.query_tokens, std=0.02) + + # Re-initialize target text encoder modules (v3 mode) + if hasattr(self, 'target_text_encoder'): + for name, module in self.target_text_encoder.named_modules(): + if isinstance(module, (torch.nn.Linear,)): + if 'out_proj' not in name: + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, torch.nn.Embedding): + torch.nn.init.normal_(module.weight, std=0.02) + if module.padding_idx is not None: + torch.nn.init.zeros_(module.weight[module.padding_idx]) + + # Re-initialize condition cross-attention modules + for block in self.vace_blocks: + if hasattr(block, 'condition_cross_attn'): + ca = block.condition_cross_attn + for name, module in ca.named_modules(): + if isinstance(module, torch.nn.Linear): + if name == 'o': + torch.nn.init.zeros_(module.weight) + torch.nn.init.zeros_(module.bias) + else: + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # Cast re-initialized modules to model dtype (xavier_uniform_ etc. 
produce float32) + target_dtype = next((p.dtype for p in self.parameters() if p.dtype != torch.float32), torch.bfloat16) + if hasattr(self, 'glyph_encoder'): + self.glyph_encoder.to(dtype=target_dtype) + if hasattr(self, 'target_text_encoder'): + self.target_text_encoder.to(dtype=target_dtype) + for block in self.vace_blocks: + if hasattr(block, 'condition_cross_attn'): + block.condition_cross_attn.to(dtype=target_dtype) + + return result + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + def forward( + self, x, vace_context, context, t_mod, freqs, + glyph_latent=None, + target_text_ids=None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ): + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # Encode condition tokens for cross-attention + condition_tokens = None + if glyph_latent is not None and hasattr(self, 'glyph_encoder'): + condition_tokens = self.glyph_encoder(glyph_latent) + elif target_text_ids is not None and hasattr(self, 'target_text_encoder'): + condition_tokens = self.target_text_encoder(target_text_ids) + + hints = [] + for block in self.vace_blocks: + result = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, x, context, t_mod, freqs, condition_tokens + ) + c_skip, c = result + hints.append(c_skip) + return hints diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e4a8a0df45f7c836efe4feb9a0cd6f04fbad3d --- /dev/null +++ b/diffsynth/models/wan_video_vae.py @@ -0,0 +1,1398 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. 
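+    # forward() computes F.normalize(x) * sqrt(dim) * gamma + bias. Since
+    # ||x||_2 = sqrt(dim) * RMS(x), this is the standard RMSNorm
+    # x / RMS(x) * gamma (+ bias), taken over the channel axis when
+    # channel_first, otherwise over the last axis.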
+ + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x, feat_cache, feat_idx + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + 
nn.init.zeros_(conv.bias.data) + + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h, feat_cache, feat_idx + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
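+    Note: the block-causal attention mask is commented out in forward(), so as
+    written this reduces to plain single-head self-attention over the spatial
+    positions of each frame (frames are folded into the batch axis).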
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else 
"downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x, feat_cache, feat_idx = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy), feat_cache, feat_idx + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut, feat_cache, feat_idx + else: + return x_main, feat_cache, feat_idx + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + 
## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x, feat_cache, feat_idx + + +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x, feat_cache, feat_idx + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + 
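+        # Mirrors Encoder3d in reverse: conv1 lifts z_dim back to the widest
+        # channel count, the middle residual/attention blocks operate at that
+        # width, and the upsample stages walk dim_mult backwards.
+        # temperal_upsample is typically the encoder's temperal_downsample
+        # reversed (see VideoVAE_).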
super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x, feat_cache, feat_idx + + + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = 
temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x, feat_cache, feat_idx + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, 
self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + self.z_dim = z_dim + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * 
self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.clamp_(-1, 1) + return values + + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x + + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.clamp_(-1, 1) + + + def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * 
self.upsampling_factor, tile_stride[1] * self.upsampling_factor) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos + + + def encode_framewise(self, videos, device): + hidden_states = [] + for i in range(videos.shape[2]): + hidden_states.append(self.single_encode(videos[:, :, i:i+1], device)) + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states + + + def decode_framewise(self, hidden_states, device): + video = [] + for i in range(hidden_states.shape[2]): + video.append(self.single_decode(hidden_states[:, :, i:i+1], device)) + video = torch.concat(video, dim=2) + return video + + + @staticmethod + def state_dict_converter(): + return WanVideoVAEStateDictConverter() + + +class WanVideoVAEStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + 
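+        # Denormalize the latents with (mean, 1/std), decode them one latent
+        # frame at a time (feat_cache / feat_idx carry the causal temporal
+        # state between chunks), then unpatchify by 2 to undo the pixel
+        # shuffle applied in encode().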
self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/diffsynth/models/wantodance.py b/diffsynth/models/wantodance.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9ddc936718b8e820587100fae06c52b429d0b6 --- /dev/null +++ b/diffsynth/models/wantodance.py @@ -0,0 +1,209 @@ +from inspect import isfunction +from math import log, pi + +import torch +from einops import rearrange, repeat +from torch import einsum, nn + +from typing import Any, Callable, List, Optional, Union +from torch import Tensor +import torch.nn.functional as F + +# helper functions + + +def exists(val): + return val is not None + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, 
dim=dim) + + +# rotary embedding helper functions + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def apply_rotary_emb(freqs, t, start_index=0): + freqs = freqs.to(t) + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + return torch.cat((t_left, t, t_right), dim=-1) + + +# learned rotation helpers + + +def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): + if exists(freq_ranges): + rotations = einsum("..., f -> ... f", rotations, freq_ranges) + rotations = rearrange(rotations, "... r f -> ... (r f)") + + rotations = repeat(rotations, "... n -> ... (n r)", r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes + + +class WanToDanceRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + ): + super().__init__() + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + self.cache = dict() + + if learned_freq: + self.freqs = nn.Parameter(freqs) + else: + self.register_buffer("freqs", freqs, persistent=False) + + def rotate_queries_or_keys(self, t, seq_dim=-2): + device = t.device + seq_len = t.shape[seq_dim] + freqs = self.forward( + lambda: torch.arange(seq_len, device=device), cache_key=seq_len + ) + return apply_rotary_emb(freqs, t) + + def forward(self, t, cache_key=None): + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if isfunction(t): + t = t() + + # freqs = self.freqs + freqs = self.freqs.to(t.device) + + freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs) + freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) + + if exists(cache_key): + self.cache[cache_key] = freqs + + return freqs + + +class WanToDanceMusicEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = True, + device=None, + dtype=None, + rotary=None, + ) -> None: + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype + ) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm_first = norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.activation = activation + + self.rotary = rotary + self.use_rotary = rotary is not None + + # self-attention block + def _sa_block( + self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] + ) -> Tensor: + qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x + x = self.self_attn( + qk, + qk, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + x = src + if self.norm_first: + self.norm1.to(device=x.device) + self.norm2.to(device=x.device) + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + return x \ No newline at end of file diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..8807302d815a917123b794a597fc5fe84d3394fc --- /dev/null +++ b/diffsynth/models/wav2vec.py @@ -0,0 +1,191 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for 
audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) + + +class WanS2VAudioEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + from transformers import Wav2Vec2ForCTC, Wav2Vec2Config + config = { + "_name_or_path": "facebook/wav2vec2-large-xlsr-53", + "activation_dropout": 0.05, + "apply_spec_augment": True, + "architectures": ["Wav2Vec2ForCTC"], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": True, + "conv_dim": [512, 512, 512, 512, 512, 512, 512], + "conv_kernel": [10, 3, 3, 3, 3, 2, 2], + "conv_stride": [5, 2, 2, 2, 2, 2, 2], + "ctc_loss_reduction": "mean", + "ctc_zero_infinity": True, + "do_stable_layer_norm": True, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.05, + "final_dropout": 0.0, + "hidden_act": "gelu", + "hidden_dropout": 0.05, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.05, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.7.0.dev0", + "vocab_size": 33 + } + self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config)) + self.video_rate = 30 + + def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'): + input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device) + + # retrieve logits & take argmax + res = self.model(input_values, output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + return feat + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + 
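+                # Bucket index falls past the end of the audio: emit an
+                # all-zero embedding with the same width as a real frame.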
frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0 + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'): + audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device) + audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype) + audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)] + return audio_embeds diff --git a/diffsynth/models/z_image_controlnet.py b/diffsynth/models/z_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5105534c736d1908fb5210eda0bfbd07c3b281df --- /dev/null +++ b/diffsynth/models/z_image_controlnet.py @@ -0,0 +1,154 @@ +from .z_image_dit import ZImageTransformerBlock +from ..core.gradient import gradient_checkpoint_forward +from torch.nn.utils.rnn import pad_sequence +import torch +from torch import nn + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int = 1000, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + norm_eps: float = 1e-5, + qk_norm: bool = True, + modulation = True, + block_id = 0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + self.after_proj = nn.Linear(self.dim, self.dim) + + def 
forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class ZImageControlNet(torch.nn.Module): + def __init__( + self, + control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + control_in_dim=33, + dim=3840, + n_refiner_layers=2, + ): + super().__init__() + self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places]) + self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)}) + self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)]) + self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14} + + def forward_layers( + self, + x, + cap_feats, + control_context, + control_context_item_seqlens, + kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + bsz = len(control_context) + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + control_context_len = control_context_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]])) + c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for layer in self.control_layers: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + return hints + + def forward_refiner( + self, + dit, + x, + cap_feats, + control_context, + kwargs, + t=None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + control_context_size, + control_context_pos_ids, + control_context_inner_pad_mask, + ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + control_context_item_seqlens = [len(_) for _ in control_context] + assert all(_ % 2 == 0 for _ in control_context_item_seqlens) + control_context_max_item_seqlen = max(control_context_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device) + control_context = list(control_context.split(control_context_item_seqlens, dim=0)) + control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + control_context_freqs_cis = pad_sequence(control_context_freqs_cis, 
batch_first=True, padding_value=0.0) + control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(control_context_item_seqlens): + control_context_attn_mask[i, :seq_len] = 1 + c = control_context + + # arguments + new_kwargs = dict( + x=x, + attn_mask=control_context_attn_mask, + freqs_cis=control_context_freqs_cis, + adaln_input=adaln_input, + ) + new_kwargs.update(kwargs) + + for layer in self.control_noise_refiner: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + control_context = torch.unbind(c)[-1] + + return hints, control_context, control_context_item_seqlens \ No newline at end of file diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0dc33f48ef2f338dae87258b24ca35b0954d68 --- /dev/null +++ b/diffsynth/models/z_image_dit.py @@ -0,0 +1,1153 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from .general_modules import RMSNorm +from ..core.attention import attention_forward +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type +from ..core.gradient import gradient_checkpoint_forward + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + mid_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + mid_size, + out_size, + bias=True, + ), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast(get_device_type(), enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(torch.bfloat16)) + return t_emb + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = 
torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)]) + + self.norm_q = RMSNorm(head_dim, eps=1e-5) + self.norm_k = RMSNorm(head_dim, eps=1e-5) + + # Apply RoPE + def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast(get_device_type(), enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + def forward(self, hidden_states, freqs_cis, attention_mask): + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = query.unflatten(-1, (self.num_heads, -1)) + key = key.unflatten(-1, (self.num_heads, -1)) + value = value.unflatten(-1, (self.num_heads, -1)) + + # Apply Norms + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + if freqs_cis is not None: + query = self.apply_rotary_emb(query, freqs_cis) + key = self.apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # Compute joint attention + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + attn_mask=attention_mask, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = self.to_out[0](hidden_states) + if len(self.to_out) > 1: # dropout + output = self.to_out[1](output) + + return output + + +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + q_dim=dim, + num_heads=n_heads, + head_dim=dim // n_heads, + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, + ): + if self.modulation: + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = 
self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), 
freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + if IS_NPU_AVAILABLE: + result.append(torch.index_select(self.freqs_cis[i], 0, index)) + else: + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageDiT(nn.Module): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock"] + + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + siglip_feat_dim=None, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + # Optional SigLIP components (for Omni variant) + self.siglip_feat_dim = siglip_feat_dim + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ 
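+ # Main transformer stack: n_layers modulated blocks that later process the unified token
+ # sequence (image latents, caption tokens, and optional SigLIP tokens).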
+ ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size = 2, + f_patch_size = 1, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 
0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def patchify_controlnet( + self, + all_image: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + cap_padding_len: int = None, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token 
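+ # Inner padding positions are overwritten with the learned pad token before RoPE and batch
+ # padding. Example with hypothetical lengths item_seqlens = [96, 64]: feats is padded to a
+ # (2, 96, dim) batch and attn_mask leaves the trailing 32 positions of the second item masked out.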
+ feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device) + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : 
unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int = 2, + f_patch_size: int = 1, + images_noise_mask: List[List[int]] = None, + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + 
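+ # cap_noise carries one noisy/clean flag per caption token, so per-token adaLN modulation can
+ # treat tokens belonging to clean conditioning images and to the noisy target differently.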
all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, + ) + return all_x_out, all_cap_out, all_sig_out, { + "x_size": x_size, + "x_pos_ids": all_x_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "sig_pos_ids": all_sig_pos_ids, + "x_pad_mask": all_x_pad_mask, + "cap_pad_mask": all_cap_pad_mask, + "sig_pad_mask": all_sig_pad_mask, + "x_pos_offsets": all_x_pos_offsets, + "x_noise_mask": 
all_x_noise_mask, + "cap_noise_mask": all_cap_noise_mask, + "sig_noise_mask": all_sig_noise_mask, + } + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + siglip_feats = None, + image_noise_mask = None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # x embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) + + for layer in self.noise_refiner: + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean, + ) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) + + for layer in self.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=cap_mask, + freqs_cis=cap_freqs, + ) + + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) + + for layer in self.siglip_refiner: + siglip_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs, + ) + + # 
Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean + ) + + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) + + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return x diff --git a/diffsynth/models/z_image_image2lora.py b/diffsynth/models/z_image_image2lora.py new file mode 100644 index 0000000000000000000000000000000000000000..757f3f6778bb24187333451bf30d6773b867ad77 --- /dev/null +++ b/diffsynth/models/z_image_image2lora.py @@ -0,0 +1,189 @@ +import torch +from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"): + super().__init__() + self.prefix = prefix + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class ZImageImage2LoRAComponent(torch.nn.Module): + def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + 
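+ # When use_residual is False the residual stream is discarded, so every LoRATrainerBlock
+ # predicts its LoRA matrices from the compressed image embedding alone.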
residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + +class ZImageImage2LoRAModel(torch.nn.Module): + def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + lora_patterns = [ + [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ], + [ + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ], + ] + config = { + "lora_patterns": lora_patterns, + "use_residual": use_residual, + "compress_dim": compress_dim, + "rank": rank, + "residual_length": residual_length, + "residual_mid_dim": residual_mid_dim, + } + self.layers_lora = ZImageImage2LoRAComponent( + prefix="layers", + num_blocks=30, + **config, + ) + self.context_refiner_lora = ZImageImage2LoRAComponent( + prefix="context_refiner", + num_blocks=2, + **config, + ) + self.noise_refiner_lora = ZImageImage2LoRAComponent( + prefix="noise_refiner", + num_blocks=2, + **config, + ) + + def forward(self, x, residual=None): + lora = {} + lora.update(self.layers_lora(x, residual=residual)) + lora.update(self.context_refiner_lora(x, residual=residual)) + lora.update(self.noise_refiner_lora(x, residual=residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) + + +class ImageEmb2LoRAWeightCompressed(torch.nn.Module): + def __init__(self, in_dim, out_dim, emb_dim, rank): + super().__init__() + self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim))) + self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank))) + self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True) + self.rank = rank + + def forward(self, x): + x = self.proj(x).view(self.rank, self.rank) + lora_a = x @ self.lora_a + lora_b = self.lora_b + return lora_a, lora_b + + +class ZImageImage2LoRAModelCompressed(torch.nn.Module): + def __init__(self, emb_dim=1536+4096, rank=32): + super().__init__() + target_layers = [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ] + self.lora_patterns = [ + { + "prefix": "layers", + "num_layers": 30, + "target_layers": target_layers, + }, + { + "prefix": "context_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + { + "prefix": "noise_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + ] + module_dict = {} + for lora_pattern in self.lora_patterns: + prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"] + for layer_id in range(num_layers): + for layer_name, in_dim, out_dim in target_layers: + name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___") + model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank) + module_dict[name] = model + self.module_dict = torch.nn.ModuleDict(module_dict) + + def forward(self, x, residual=None): + lora = {} + for name, module in self.module_dict.items(): + name = name.replace("___", ".") + name_a, 
name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight" + lora_a, lora_b = module(x) + lora[name_a] = lora_a + lora[name_b] = lora_b + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if "lora_b" in name: + state_dict[name] = state_dict[name] * 0 + elif "lora_a" in name: + state_dict[name] = state_dict[name] * 0.2 + elif "proj.weight" in name: + print(name) + state_dict[name] = state_dict[name] * 0.2 + self.load_state_dict(state_dict) diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3e6c00b723be3dbf9a2e3a69bf8ce90b689c2a --- /dev/null +++ b/diffsynth/models/z_image_text_encoder.py @@ -0,0 +1,104 @@ +from transformers import Qwen3Model, Qwen3Config +import torch + + +class ZImageTextEncoder(torch.nn.Module): + def __init__(self, model_size="4B"): + super().__init__() + config_dict = { + "0.6B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 40960, + "max_window_layers": 28, + "model_type": "qwen3", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), + "4B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), + "8B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": False, + "transformers_version": "4.56.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }) + } + config = config_dict[model_size] + self.model = Qwen3Model(config) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py new file mode 100644 index 
0000000000000000000000000000000000000000..bc4f6cd31eda74cbebd3047ee8845c159999c1aa --- /dev/null +++ b/diffsynth/pipelines/anima_image.py @@ -0,0 +1,264 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from math import prod +from transformers import AutoTokenizer + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora + +from ..models.anima_dit import AnimaDiT +from ..models.z_image_text_encoder import ZImageTextEncoder +from ..models.wan_video_vae import WanVideoVAE + + +class AnimaImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Z-Image") + self.text_encoder: ZImageTextEncoder = None + self.dit: AnimaDiT = None + self.vae: WanVideoVAE = None + self.tokenizer: AutoTokenizer = None + self.tokenizer_t5xxl: AutoTokenizer = None + self.in_iteration_models = ("dit",) + self.units = [ + AnimaUnit_ShapeChecker(), + AnimaUnit_NoiseInitializer(), + AnimaUnit_InputImageEmbedder(), + AnimaUnit_PromptEmbedder(), + ] + self.model_fn = model_fn_anima + self.compilable_models = ["dit"] + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"), + tokenizer_t5xxl_config: ModelConfig = ModelConfig(model_id="stabilityai/stable-diffusion-3.5-large", origin_file_pattern="tokenizer_3/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AnimaImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("anima_dit") + pipe.vae = model_pool.fetch_model("wan_video_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + if tokenizer_t5xxl_config is not None: + tokenizer_t5xxl_config.download_if_necessary() + pipe.tokenizer_t5xxl = AutoTokenizer.from_pretrained(tokenizer_t5xxl_config.path) + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + sigma_shift: float = None, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, 
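+ # inputs_shared is threaded through every PipelineUnit below; each unit reads the keys it
+ # declares as input_params and writes back the keys listed in output_params.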
"denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"].unsqueeze(2), device=self.device).squeeze(2) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AnimaUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: AnimaImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class AnimaUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AnimaImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + + +class AnimaUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AnimaImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + if isinstance(input_image, list): + input_latents = [] + for image in input_image: + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents.append(pipe.vae.encode(image)) + input_latents = torch.concat(input_latents, dim=0) + else: + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image.unsqueeze(2), device=pipe.device).squeeze(2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class AnimaUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe: AnimaImagePipeline, + prompt, + device = None, + max_sequence_length: int = 512, + ): + if isinstance(prompt, str): + prompt = [prompt] + + text_inputs = pipe.tokenizer( + prompt, + 
padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-1] + + t5xxl_text_inputs = pipe.tokenizer_t5xxl( + prompt, + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + t5xxl_ids = t5xxl_text_inputs.input_ids.to(device) + + return prompt_embeds.to(pipe.torch_dtype), t5xxl_ids + + def process(self, pipe: AnimaImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, t5xxl_ids = self.encode_prompt(pipe, prompt, pipe.device) + return {"prompt_emb": prompt_embeds, "t5xxl_ids": t5xxl_ids} + + +def model_fn_anima( + dit: AnimaDiT = None, + latents=None, + timestep=None, + prompt_emb=None, + t5xxl_ids=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs +): + latents = latents.unsqueeze(2) + timestep = timestep / 1000 + model_output = dit( + x=latents, + timesteps=timestep, + context=prompt_emb, + t5xxl_ids=t5xxl_ids, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + model_output = model_output.squeeze(2) + return model_output diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7b6dcc492beb8d93c1f2ef29fbf5f3f8bc325c24 --- /dev/null +++ b/diffsynth/pipelines/flux2_image.py @@ -0,0 +1,595 @@ +import torch, math, torchvision +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput + +from transformers import AutoProcessor, AutoTokenizer +from ..models.flux2_text_encoder import Flux2TextEncoder +from ..models.flux2_dit import Flux2DiT +from ..models.flux2_vae import Flux2VAE +from ..models.z_image_text_encoder import ZImageTextEncoder + + +class Flux2ImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: Flux2TextEncoder = None + self.text_encoder_qwen3: ZImageTextEncoder = None + self.dit: Flux2DiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + Flux2Unit_ShapeChecker(), + Flux2Unit_PromptEmbedder(), + Flux2Unit_Qwen3PromptEmbedder(), + Flux2Unit_NoiseInitializer(), + Flux2Unit_InputImageEmbedder(), + Flux2Unit_EditImageEmbedder(), + Flux2Unit_ImageIDs(), + ] + self.model_fn = model_fn_flux2 + self.compilable_models = ["dit"] + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", 
origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder") + pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("flux2_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Edit + edit_image: Union[Image.Image, List[Image.Image]] = None, + edit_image_auto_resize: bool = True, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + initial_noise: torch.Tensor = None, + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, + "input_image": input_image, "denoising_strength": denoising_strength, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, "initial_noise": initial_noise, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + latents = rearrange(inputs_shared["latents"], "B (H W) C -> B C H W", H=inputs_shared["height"]//16, W=inputs_shared["width"]//16) + image = self.vae.decode(latents) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class Flux2Unit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class Flux2Unit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + 
input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + + def format_text_input(self, prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + def get_mistral_3_small_prompt_embeds( + self, + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = self.format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_mistral_3_small_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + 
prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder) + if pipe.text_encoder_qwen3 is not None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + +class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder_qwen3",) + ) + self.hidden_states_layers = (9, 18, 27) # Qwen3 layers + + def get_qwen3_prompt_embeds( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int 
= 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_qwen3_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Check if Qwen3 text encoder is available + if pipe.text_encoder_qwen3 is None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder_qwen3, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + +class Flux2Unit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device", "initial_noise"), + output_params=("noise",), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device, initial_noise): + if initial_noise is not None: + noise = initial_noise.clone() + else: + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1) + return {"noise": noise} + + +class Flux2Unit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: Flux2ImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + input_latents = rearrange(input_latents, "B C H W -> B (H W) C") + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class Flux2Unit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "edit_image_auto_resize"), + output_params=("edit_latents", "edit_image_ids"), + onload_model_names=("vae",) + ) + + def calculate_dimensions(self, target_area, ratio): + import math + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def edit_image_auto_resize(self, edit_image): + calculated_width, calculated_height = 
self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) + return self.crop_and_resize(edit_image, calculated_height, calculated_width) + + def process_image_ids(self, image_latents, scale=10): + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(edit_image, Image.Image): + edit_image = [edit_image] + resized_edit_image, edit_latents = [], [] + for image in edit_image: + # Preprocess + if edit_image_auto_resize is None or edit_image_auto_resize: + image = self.edit_image_auto_resize(image) + resized_edit_image.append(image) + # Encode + image = pipe.preprocess_image(image) + latents = pipe.vae.encode(image) + edit_latents.append(latents) + edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device) + edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1) + return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids} + + +class Flux2Unit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("image_ids",), + ) + + def prepare_latent_ids(self, height, width): + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1) + + return latent_ids + + def process(self, pipe: Flux2ImagePipeline, height, width): + image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device) + return {"image_ids": image_ids} + + +def model_fn_flux2( + dit: Flux2DiT, + latents=None, + timestep=None, + embedded_guidance=None, + prompt_embeds=None, + text_ids=None, + image_ids=None, + edit_latents=None, + edit_image_ids=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + image_seq_len = latents.shape[1] + if edit_latents is not None: + image_seq_len = latents.shape[1] + latents = torch.concat([latents, edit_latents], dim=1) + image_ids = torch.concat([image_ids, edit_image_ids], dim=1) + embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + model_output = model_output[:, :image_seq_len] + return model_output diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py new file mode 100644 index 0000000000000000000000000000000000000000..db2d5224e1635028e2cfd244d949e857e8a1917b --- /dev/null +++ b/diffsynth/pipelines/flux_image.py @@ -0,0 +1,1207 @@ 
+import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.flux import FluxLoRALoader + +from ..models.flux_dit import FluxDiT +from ..models.flux_text_encoder_clip import FluxTextEncoderClip +from ..models.flux_text_encoder_t5 import FluxTextEncoderT5 +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.flux_value_control import MultiValueEncoder +from ..models.step1x_text_encoder import Step1xEditEmbedder +from ..core.vram.layers import AutoWrappedLinear + +class MultiControlNet(torch.nn.Module): + def __init__(self, models: list[torch.nn.Module]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + + def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs): + model = self.models[controlnet_input.controlnet_id] + res_stack, single_res_stack = model( + controlnet_conditioning=conditioning, + processor_id=controlnet_input.processor_id, + **kwargs + ) + res_stack = [res * controlnet_input.scale for res in res_stack] + single_res_stack = [res * controlnet_input.scale for res in single_res_stack] + return res_stack, single_res_stack + + def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs): + res_stack, single_res_stack = None, None + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start or progress < controlnet_input.end: + continue + res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs) + if res_stack is None: + res_stack = res_stack_ + single_res_stack = single_res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)] + return res_stack, single_res_stack + + +class FluxImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.1") + self.tokenizer_1: CLIPTokenizer = None + self.tokenizer_2: T5TokenizerFast = None + self.text_encoder_1: FluxTextEncoderClip = None + self.text_encoder_2: FluxTextEncoderT5 = None + self.dit: FluxDiT = None + self.vae_decoder: FluxVAEDecoder = None + self.vae_encoder: FluxVAEEncoder = None + self.controlnet = None + self.ipadapter = None + self.ipadapter_image_encoder = None + self.qwenvl = None + self.step1x_connector = None + self.nexus_gen = None + self.nexus_gen_generation_adapter = None + self.nexus_gen_editing_adapter = None + self.value_controller = None + self.infinityou_processor = None + self.image_proj_model = None + self.lora_patcher = None + self.lora_encoder = None + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") + self.units = [ + 
FluxImageUnit_ShapeChecker(), + FluxImageUnit_NoiseInitializer(), + FluxImageUnit_PromptEmbedder(), + FluxImageUnit_InputImageEmbedder(), + FluxImageUnit_ImageIDs(), + FluxImageUnit_EmbeddedGuidanceEmbedder(), + FluxImageUnit_Kontext(), + FluxImageUnit_InfiniteYou(), + FluxImageUnit_ControlNet(), + FluxImageUnit_IPAdapter(), + FluxImageUnit_EntityControl(), + FluxImageUnit_NexusGen(), + FluxImageUnit_TeaCache(), + FluxImageUnit_Flex(), + FluxImageUnit_Step1x(), + FluxImageUnit_ValueControl(), + FluxImageUnit_LoRAEncode(), + ] + self.model_fn = model_fn_flux_image + self.compilable_models = ["dit"] + self.lora_loader = FluxLoRALoader + + def enable_lora_merger(self): + if not (hasattr(self.dit, "vram_management_enabled") and getattr(self.dit, "vram_management_enabled")): + raise ValueError("DiT VRAM management is not enabled.") + if self.lora_patcher is not None: + for name, module in self.dit.named_modules(): + if isinstance(module, AutoWrappedLinear): + merger_name = name.replace(".", "___") + if merger_name in self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), + tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"), + nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), + step1x_processor_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern=""), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder_1 = model_pool.fetch_model("flux_text_encoder_clip") + pipe.text_encoder_2 = model_pool.fetch_model("flux_text_encoder_t5") + pipe.dit = model_pool.fetch_model("flux_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + if tokenizer_1_config is not None: + tokenizer_1_config.download_if_necessary() + pipe.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_config.path) + if tokenizer_2_config is not None: + tokenizer_2_config.download_if_necessary() + pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path) + + value_controllers = model_pool.fetch_model("flux_value_controller") + if value_controllers is not None: + pipe.value_controller = MultiValueEncoder(value_controllers) + if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"): + pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled + controlnets = model_pool.fetch_model("flux_controlnet") + if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets) + pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") + pipe.ipadapter_image_encoder = model_pool.fetch_model("siglip_vision_model") + qwenvl = model_pool.fetch_model("qwen_image_text_encoder") + if qwenvl is not None: + from transformers import AutoProcessor + step1x_processor_config.download_if_necessary() + processor = AutoProcessor.from_pretrained(step1x_processor_config.path, min_pixels=256 
* 28 * 28, max_pixels=324 * 28 * 28) + pipe.qwenvl = Step1xEditEmbedder(qwenvl, processor) + pipe.step1x_connector = model_pool.fetch_model("step1x_connector") + pipe.image_proj_model = model_pool.fetch_model("infiniteyou_image_projector") + if pipe.image_proj_model is not None: + pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_pool.fetch_model("flux_lora_patcher") + pipe.lora_encoder = model_pool.fetch_model("flux_lora_encoder") + pipe.nexus_gen = model_pool.fetch_model("nexus_gen_llm") + pipe.nexus_gen_generation_adapter = model_pool.fetch_model("nexus_gen_generation_adapter") + pipe.nexus_gen_editing_adapter = model_pool.fetch_model("nexus_gen_editing_adapter") + if pipe.nexus_gen is not None: + nexus_gen_processor_config.download_if_necessary() + pipe.nexus_gen.load_processor(nexus_gen_processor_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 3.5, + t5_sequence_length: int = 512, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Scheduler + sigma_shift: float = None, + # Steps + num_inference_steps: int = 30, + # local prompts + multidiffusion_prompts=(), + multidiffusion_masks=(), + multidiffusion_scales=(), + # Kontext + kontext_images: Union[list[Image.Image], Image.Image] = None, + # ControlNet + controlnet_inputs: list[ControlNetInput] = None, + # IP-Adapter + ipadapter_images: Union[list[Image.Image], Image.Image] = None, + ipadapter_scale: float = 1.0, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + eligen_enable_inpaint: bool = False, + # InfiniteYou + infinityou_id_image: Image.Image = None, + infinityou_guidance: float = 1.0, + # Flex + flex_inpaint_image: Image.Image = None, + flex_inpaint_mask: Image.Image = None, + flex_control_image: Image.Image = None, + flex_control_strength: float = 0.5, + flex_control_stop: float = 0.5, + # Value Controller + value_controller_inputs: Union[list[float], float] = None, + # Step1x + step1x_reference_image: Image.Image = None, + # NexusGen + nexus_gen_reference_image: Image.Image = None, + # LoRA Encoder + lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, + lora_encoder_scale: float = 1.0, + # TeaCache + tea_cache_l1_thresh: float = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, + "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, + "kontext_images": kontext_images, + 
"controlnet_inputs": controlnet_inputs, + "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, + "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, + "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "value_controller_inputs": value_controller_inputs, + "step1x_reference_image": step1x_reference_image, + "nexus_gen_reference_image": nexus_gen_reference_image, + "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, + "tea_cache_l1_thresh": tea_cache_l1_thresh, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "progress_bar_cmd": progress_bar_cmd, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class FluxImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width"), output_params=("height", "width")) + + def process(self, pipe: FluxImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class FluxImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",)) + + def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) + return {"noise": noise} + + + +class FluxImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + 
else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": None} + + + +class FluxImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + input_params=("t5_sequence_length",), + output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device=get_device_type(), + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict: + if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length, + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} + + +class FluxImageUnit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents",), output_params=("image_ids",)) + + def process(self, pipe: FluxImagePipeline, latents): + latent_image_ids = pipe.dit.prepare_image_ids(latents) + return {"image_ids": latent_image_ids} + + + +class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",)) + + def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): + guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + return {"guidance": guidance} + + + +class FluxImageUnit_Kontext(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("kontext_images", "tiled", "tile_size", "tile_stride"), + output_params=("kontext_latents", "kontext_image_ids"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): + if kontext_images is None: + return {} + if not isinstance(kontext_images, list): + kontext_images = [kontext_images] + + kontext_latents = [] + 
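# Encode each Kontext reference image into latents, tag its image ids with t = 1 so reference tokens stay distinguishable from the denoised latents, then patchify and concatenate along the sequence dimension. +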
kontext_image_ids = [] + for kontext_image in kontext_images: + kontext_image = pipe.preprocess_image(kontext_image) + kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image_ids = pipe.dit.prepare_image_ids(kontext_latent) + image_ids[..., 0] = 1 + kontext_image_ids.append(image_ids) + kontext_latent = pipe.dit.patchify(kontext_latent) + kontext_latents.append(kontext_latent) + kontext_latents = torch.concat(kontext_latents, dim=1) + kontext_image_ids = torch.concat(kontext_image_ids, dim=-2) + return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids} + + + +class FluxImageUnit_ControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("controlnet_conditionings",), + onload_model_names=("vae_encoder",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if controlnet_inputs is None: + return {} + pipe.load_models_to_device(['vae_encoder']) + conditionings = [] + for controlnet_input in controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + return {"controlnet_conditionings": conditionings} + + + +class FluxImageUnit_IPAdapter(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("ipadapter_images", "ipadapter_scale"), + output_params=("ipadapter_kwargs_list",), + onload_model_names=("ipadapter_image_encoder", "ipadapter") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) + if ipadapter_images is None: + return inputs_shared, inputs_posi, inputs_nega + if not isinstance(ipadapter_images, list): + ipadapter_images = [ipadapter_images] + + pipe.load_models_to_device(self.onload_model_names) + images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] + images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images] + ipadapter_images = torch.cat(images, dim=0) + ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output + + inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + 
inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))}) + return inputs_shared, inputs_posi, inputs_nega + + + +class FluxImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device=get_device_type(), + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + + prompt_emb, _, _ = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length, + ) + return prompt_emb.unsqueeze(0), entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: FluxImagePipeline, 
inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_NexusGen(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("nexus_gen_reference_image", "prompt", "latents"), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.nexus_gen is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + if inputs_shared.get("nexus_gen_reference_image", None) is None: + assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set." + embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0) + inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed) + inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype) + else: + assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set." 
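+ # Editing path: encode the prompt together with the reference image, fuse both through the editing adapter, and build text ids that give the reference tokens their own t coordinate.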
+ embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"]) + embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long) + ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long) + + inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid) + inputs_posi["text_ids"] = self.get_editing_text_ids( + inputs_shared["latents"], + embeds_grid[0][1].item(), embeds_grid[0][2].item(), + ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(), + ) + return inputs_shared, inputs_posi, inputs_nega + + + def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width): + # prepare text ids for target and reference embeddings + batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width + embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype) + + batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width + ref_embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0 + ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype) + + text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1) + return text_ids + + +class FluxImageUnit_Step1x(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("step1x_reference_image", "prompt", "negative_prompt"), + output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"), + onload_model_names=("qwenvl","vae_encoder") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): + image = inputs_shared.get("step1x_reference_image",None) + if image is None: + return inputs_shared, inputs_posi, inputs_nega + else: + pipe.load_models_to_device(self.onload_model_names) + prompt = inputs_posi["prompt"] + nega_prompt = inputs_nega["negative_prompt"] + captions = [prompt, nega_prompt] + ref_images = [image, image] + embs, masks = pipe.qwenvl(captions, ref_images) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image) + inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}) + if inputs_shared.get("cfg_scale", 1) != 1: + inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_TeaCache(PipelineUnit): + def 
__init__(self): + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",)) + + def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): + if tea_cache_l1_thresh is None: + return {} + else: + return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)} + +class FluxImageUnit_Flex(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), + output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): + if pipe.dit.input_dim == 196: + if flex_control_stop is None: + flex_control_stop = 1 + pipe.load_models_to_device(self.onload_model_names) + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))] + return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} + else: + return {} + + + +class FluxImageUnit_InfiniteYou(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("infinityou_id_image", "infinityou_guidance"), + output_params=("id_emb", "infinityou_guidance"), + onload_model_names=("infinityou_processor",) + ) + + def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance): + pipe.load_models_to_device("infinityou_processor") + if infinityou_id_image is not None: + return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device) + else: + return {} + + + +class FluxImageUnit_ValueControl(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + 
input_params=("value_controller_inputs",), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("value_controller",) + ) + + def add_to_text_embedding(self, prompt_emb, text_ids, value_emb): + prompt_emb = torch.concat([prompt_emb, value_emb], dim=1) + extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs): + if value_controller_inputs is None: + return {} + if not isinstance(value_controller_inputs, list): + value_controller_inputs = [value_controller_inputs] + value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device) + pipe.load_models_to_device(["value_controller"]) + value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype) + value_emb = value_emb.unsqueeze(0) + prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb) + return {"prompt_emb": prompt_emb, "text_ids": text_ids} + + + +class InfinitYou(torch.nn.Module): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__() + from facexlib.recognition import init_recognition_model + from insightface.app import FaceAnalysis + self.device = device + self.torch_dtype = torch_dtype + insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface' + self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_640.prepare(ctx_id=0, det_size=(640, 640)) + self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_320.prepare(ctx_id=0, det_size=(320, 320)) + self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_160.prepare(ctx_id=0, det_size=(160, 160)) + self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype) + + def _detect_face(self, id_image_cv2): + face_info = self.app_640.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_320.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_160.get(id_image_cv2) + return face_info + + def extract_arcface_bgr_embedding(self, in_image, landmark, device): + from insightface.utils import face_align + arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112) + arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255. 
+ arc_face_image = 2 * arc_face_image - 1 + arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype) + face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized + return face_emb + + def prepare_infinite_you(self, model, id_image, infinityou_guidance, device): + import cv2 + if id_image is None: + return {'id_emb': None} + id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR) + face_info = self._detect_face(id_image_cv2) + if len(face_info) == 0: + raise ValueError('No face detected in the input ID image') + landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face + id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device) + id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype)) + infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype) + return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance} + + + +class FluxImageUnit_LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("lora_encoder_inputs", "lora_encoder_scale"), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("lora_encoder",) + ) + + def parse_lora_encoder_inputs(self, lora_encoder_inputs): + if not isinstance(lora_encoder_inputs, list): + lora_encoder_inputs = [lora_encoder_inputs] + lora_configs = [] + for lora_encoder_input in lora_encoder_inputs: + if isinstance(lora_encoder_input, str): + lora_encoder_input = ModelConfig(path=lora_encoder_input) + lora_encoder_input.download_if_necessary() + lora_configs.append(lora_encoder_input) + return lora_configs + + def load_lora(self, lora_config, dtype, device): + loader = FluxLoRALoader(torch_dtype=dtype, device=device) + lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device) + lora = loader.convert_state_dict(lora) + return lora + + def lora_embedding(self, pipe, lora_encoder_inputs): + lora_emb = [] + for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs): + lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device) + lora_emb.append(pipe.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb): + prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1) + extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("lora_encoder_inputs", None) is None: + return inputs_shared, inputs_posi, inputs_nega + + # Encode + pipe.load_models_to_device(["lora_encoder"]) + lora_encoder_inputs = inputs_shared["lora_encoder_inputs"] + lora_emb = self.lora_embedding(pipe, lora_encoder_inputs) + + # Scale + lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None) + if lora_encoder_scale is not None: + lora_emb = lora_emb * lora_encoder_scale + + # Add to prompt embedding + inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding( + inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb) + return inputs_shared, inputs_posi, inputs_nega + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step 
= 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: FluxDiT, hidden_states, conditioning): + inp = hidden_states.clone() + temb_ = conditioning.clone() + modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +class FastTileWorker: + def __init__(self): + pass + + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + B, C, H, W = model_input.shape + border_width = int(tile_stride*0.5) if border_width is None else border_width + weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device) + values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + # Forward + hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + +def model_fn_flux_image( + dit: FluxDiT, + controlnet=None, + step1x_connector=None, + latents=None, + timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + kontext_latents=None, + kontext_image_ids=None, + 
controlnet_inputs=None, + controlnet_conditionings=None, + tiled=False, + tile_size=128, + tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, + ipadapter_kwargs_list={}, + id_emb=None, + infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, + tea_cache: TeaCache = None, + progress_id=0, + num_inference_steps=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs +): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None + return model_fn_flux_image( + dit=dit, + controlnet=controlnet, + latents=latents[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_inputs=controlnet_inputs, + controlnet_conditionings=tiled_controlnet_conditionings, + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + latents, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=latents.device, + tile_dtype=latents.dtype + ) + + hidden_states = latents + + # ControlNet + if controlnet is not None and controlnet_conditionings is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "controlnet_inputs": controlnet_inputs, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + "progress_id": progress_id, + "num_inference_steps": num_inference_steps, + } + if id_emb is not None: + controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype) + controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance}) + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_conditionings, **controlnet_extra_kwargs + ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + + # Kontext + if kontext_latents is not None: + image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, kontext_latents], dim=1) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + 
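+ # Patchify the Step1x reference latents and append them, together with their image ids, to the main token sequence.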
step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + + hidden_states = dit.x_embedder(hidden_states) + + # EliGen + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1]) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None: + if kontext_latents is None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + else: + hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None: + if kontext_latents is None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + else: + hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + + # Kontext + if kontext_latents is not None: + hidden_states = hidden_states[:, :-kontext_latents.shape[1]] + + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py new file mode 100644 index 0000000000000000000000000000000000000000..1263b43d7ad43cabf6944d773c72284c446a7596 --- /dev/null +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -0,0 +1,731 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from 
einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from transformers import AutoImageProcessor, Gemma3Processor + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer +from ..models.ltx2_dit import LTXModel +from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier +from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor +from ..models.ltx2_upsampler import LTX2LatentUpsampler +from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS +from ..utils.data.media_io_ltx2 import ltx2_preprocess +from ..utils.data.audio import convert_to_stereo + + +class LTX2AudioVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, + torch_dtype=torch_dtype, + height_division_factor=32, + width_division_factor=32, + time_division_factor=8, + time_division_remainder=1, + ) + self.scheduler = FlowMatchScheduler("LTX-2") + self.text_encoder: LTX2TextEncoder = None + self.tokenizer: LTXVGemmaTokenizer = None + self.processor: Gemma3Processor = None + self.text_encoder_post_modules: LTX2TextEncoderPostModules = None + self.dit: LTXModel = None + self.video_vae_encoder: LTX2VideoEncoder = None + self.video_vae_decoder: LTX2VideoDecoder = None + self.audio_vae_encoder: LTX2AudioEncoder = None + self.audio_vae_decoder: LTX2AudioDecoder = None + self.audio_vocoder: LTX2Vocoder = None + self.upsampler: LTX2LatentUpsampler = None + + self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1) + self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1) + self.audio_processor: AudioProcessor = AudioProcessor() + + self.in_iteration_models = ("dit",) + self.units = [ + LTX2AudioVideoUnit_PipelineChecker(), + LTX2AudioVideoUnit_ShapeChecker(), + LTX2AudioVideoUnit_PromptEmbedder(), + LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_VideoRetakeEmbedder(), + LTX2AudioVideoUnit_AudioRetakeEmbedder(), + LTX2AudioVideoUnit_InputAudioEmbedder(), + LTX2AudioVideoUnit_InputVideoEmbedder(), + LTX2AudioVideoUnit_InputImagesEmbedder(), + LTX2AudioVideoUnit_InContextVideoEmbedder(), + ] + self.stage2_units = [ + LTX2AudioVideoUnit_SwitchStage2(), + LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_LatentsUpsampler(), + LTX2AudioVideoUnit_VideoRetakeEmbedder(), + LTX2AudioVideoUnit_AudioRetakeEmbedder(), + LTX2AudioVideoUnit_InputImagesEmbedder(), + LTX2AudioVideoUnit_SetScheduleStage2(), + ] + self.model_fn = model_fn_ltx2 + self.compilable_models = ["dit"] + + self.default_negative_prompt = { + "LTX-2": ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, 
inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." + ), + "LTX-2.3": ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." + ), + } + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config: Optional[ModelConfig] = None, + stage2_lora_strength: float = 0.8, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder") + tokenizer_config.download_if_necessary() + pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path) + image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True) + pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer) + + pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules") + pipe.dit = model_pool.fetch_model("ltx2_dit") + pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder") + pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") + pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") + pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") + pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") + + # Stage 2 + if stage2_lora_config is not None: + pipe.stage2_lora_config = stage2_lora_config + pipe.stage2_lora_strength = stage2_lora_strength + + # VRAM Management + pipe.vram_management_enabled = 
pipe.check_vram_management_state() + return pipe + + def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False): + if skip_stage: + return inputs_shared, inputs_posi, inputs_nega + for unit in units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, + inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, + inpaint_mask=inputs_shared.get("denoise_mask_audio", None), input_latents=inputs_shared.get("input_latents_audio", None), **inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + denoising_strength: float = 1.0, + # Image-to-video + input_images: Optional[list[Image.Image]] = None, + input_images_indexes: Optional[list[int]] = [0], + input_images_strength: Optional[float] = 1.0, + # In-Context Video Control + in_context_videos: Optional[list[list[Image.Image]]] = None, + in_context_downsample_factor: Optional[int] = 2, + # Video-to-video + retake_video: Optional[list[Image.Image]] = None, + retake_video_regions: Optional[list[tuple[float, float]]] = None, + # Audio-to-video + retake_audio: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 48000, + retake_audio_regions: Optional[list[tuple[float, float]]] = None, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 512, + width: Optional[int] = 768, + num_frames: Optional[int] = 121, + frame_rate: Optional[int] = 24, + # Classifier-free guidance + cfg_scale: Optional[float] = 3.0, + # Scheduler + num_inference_steps: Optional[int] = 30, + # VAE tiling + tiled: Optional[bool] = True, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + # Special Pipelines + use_two_stage_pipeline: Optional[bool] = False, + stage2_spatial_upsample_factor: Optional[int] = 2, + clear_lora_before_state_two: Optional[bool] = False, + use_distilled_pipeline: Optional[bool] = False, + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None) + # Inputs + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": 
input_images_strength, + "retake_video": retake_video, "retake_video_regions": retake_video_regions, + "retake_audio": (retake_audio, audio_sample_rate) if retake_audio is not None else None, "retake_audio_regions": retake_audio_regions, + "in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate, + "cfg_scale": cfg_scale, + "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, + "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor, + "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, + } + # Stage 1 + inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd) + # Stage 2 + inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"]) + # Decode + self.load_models_to_device(['video_vae_decoder']) + video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames) + video = self.vae_output_to_video(video) + self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) + decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) + decoded_audio = self.audio_vocoder(decoded_audio) + decoded_audio = self.output_audio_format_check(decoded_audio) + return video, decoded_audio + + +class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("use_distilled_pipeline", "use_two_stage_pipeline"), + output_params=("use_two_stage_pipeline", "cfg_scale") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("use_distilled_pipeline", False): + inputs_shared["use_two_stage_pipeline"] = True + inputs_shared["cfg_scale"] = 1.0 + print("Distilled pipeline requested, setting use_two_stage_pipeline to True and disabling CFG by setting cfg_scale to 1.0.") + if inputs_shared.get("use_two_stage_pipeline", False): + # the distilled pipeline also runs two stages, but it does not need the stage-2 LoRA + if not inputs_shared.get("use_distilled_pipeline", False): + if not (hasattr(pipe, "stage2_lora_config") and pipe.stage2_lora_config is not None): + raise ValueError("Two-stage pipeline requested, but stage2_lora_config is not set in the pipeline.") + if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None): + raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") + return inputs_shared, inputs_posi, inputs_nega + + +class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): + """ + For two-stage pipelines, the resolution must be divisible by 64. + For one-stage pipelines, the resolution must be divisible by 32. + This unit sets height and width to the stage 1 resolution and computes stage_2_height and stage_2_width for stage 2.
+ """ + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"), + output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2): + if use_two_stage_pipeline: + height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor) + else: + stage_2_height, stage_2_width = None, None + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width} + + +class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): + + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("video_context", "audio_context"), + onload_model_names=("text_encoder", "text_encoder_post_modules"), + ) + def _preprocess_text( + self, + pipe, + text: str, + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"] + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device) + outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + return outputs.hidden_states, attention_mask + def encode_prompt(self, pipe, text, padding_side="left"): + hidden_states, attention_mask = self._preprocess_text(pipe, text) + video_encoding, audio_encoding, attention_mask = pipe.text_encoder_post_modules.process_hidden_states( + hidden_states, attention_mask, padding_side) + return video_encoding, audio_encoding, attention_mask + + def process(self, pipe: LTX2AudioVideoPipeline, prompt: str): + pipe.load_models_to_device(self.onload_model_names) + video_context, audio_context, _ = self.encode_prompt(pipe, prompt) + return {"video_context": video_context, "audio_context": audio_context} + + +class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"), + output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape") + ) + + def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128) + video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] 
/ frame_rate + video_positions = video_positions.to(pipe.torch_dtype) + + audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape) + audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + return { + "video_noise": video_noise, + "audio_noise": audio_noise, + "video_positions": video_positions, + "audio_positions": audio_positions, + "video_latent_shape": video_latent_shape, + "audio_latent_shape": audio_latent_shape + } + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + + +class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"), + output_params=("video_latents", "input_latents"), + onload_model_names=("video_vae_encoder") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels): + if input_video is None or not pipe.scheduler.training: + return {"video_latents": video_noise} + else: + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"video_latents": input_latents, "input_latents": input_latents} + +class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_audio", "audio_noise"), + output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"), + onload_model_names=("audio_vae_encoder",) + ) + + def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise): + if input_audio is None or not pipe.scheduler.training: + return {"audio_latents": audio_noise} + else: + input_audio, sample_rate = input_audio + input_audio = convert_to_stereo(input_audio) + pipe.load_models_to_device(self.onload_model_names) + input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype) + audio_input_latents = pipe.audio_vae_encoder(input_audio) + audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape} + + +class LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("retake_video", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "video_positions", "retake_video_regions"), + output_params=("input_latents_video", "denoise_mask_video"), + onload_model_names=("video_vae_encoder") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None): + if retake_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + resized_video = 
[frame.resize((width, height)) for frame in retake_video] + input_video = pipe.preprocess_video(resized_video) + input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + + b, c, f, h, w = input_latents_video.shape + denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype) + if retake_video_regions is not None and len(retake_video_regions) > 0: + for start_time, end_time in retake_video_regions: + t_start, t_end = video_positions[0, 0].unbind(dim=-1) + in_region = (t_end >= start_time) & (t_start <= end_time) + in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w) + denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video) + + return {"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video} + + +class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit): + """ + Functionality of audio2video, audio retaking. + """ + def __init__(self): + super().__init__( + input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"), + output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"), + onload_model_names=("audio_vae_encoder",) + ) + + def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None): + if retake_audio is None: + return {} + else: + input_audio, sample_rate = retake_audio + input_audio = convert_to_stereo(input_audio) + pipe.load_models_to_device(self.onload_model_names) + input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents_audio = pipe.audio_vae_encoder(input_audio) + audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + # Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch. 
+ audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device) + + b, c, t, f = input_latents_audio.shape + denoise_mask_audio = torch.zeros((b, 1, t, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype) + if retake_audio_regions is not None and len(retake_audio_regions) > 0: + for start_time, end_time in retake_audio_regions: + t_start, t_end = audio_positions[:, 0, :, 0], audio_positions[:, 0, :, 1] + in_region = (t_end >= start_time) & (t_start <= end_time) + in_region = pipe.audio_patchifier.unpatchify_audio(in_region.unsqueeze(-1), 1, 1) + denoise_mask_audio = torch.where(in_region, torch.ones_like(denoise_mask_audio), denoise_mask_audio) + + return { + "input_latents_audio": input_latents_audio, + "denoise_mask_audio": denoise_mask_audio, + "audio_noise": audio_noise, + "audio_positions": audio_positions, + "audio_latent_shape": audio_latent_shape, + } + + +class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "input_latents_video", "denoise_mask_video"), + output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"), + onload_model_names=("video_vae_encoder") + ) + + def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels): + image = ltx2_preprocess(np.array(input_image.resize((width, height)))) + image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device) + image = image / 127.5 - 1.0 + image = repeat(image, f"H W C -> B C F H W", B=1, F=1) + latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device) + return latents + + def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None): + b, _, f, h, w = latents.shape + denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video + input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video + for idx, input_latent in zip(input_indexes, input_latents): + idx = min(max(1 + (idx-1) // 8, 0), f - 1) + input_latent = input_latent.to(dtype=latents.dtype, device=latents.device) + input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent + denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength + return input_latents_video, denoise_mask + + def process( + self, + pipe: LTX2AudioVideoPipeline, + video_latents, + input_images, + height, + width, + frame_rate, + tiled, + tile_size_in_pixels, + tile_overlap_in_pixels, + input_images_indexes=[0], + input_images_strength=1.0, + input_latents_video=None, + denoise_mask_video=None, + ): + if input_images is None or len(input_images) == 0: + return {} + else: + if len(input_images_indexes) != len(set(input_images_indexes)): + raise ValueError("Input images must have unique indexes.") + pipe.load_models_to_device(self.onload_model_names) + frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []} + for img, index in zip(input_images, input_images_indexes): + latents = self.get_image_latent(pipe, 
img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) + # first_frame by replacing latents + if index == 0: + input_latents_video, denoise_mask_video = self.apply_input_images_to_latents( + video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video) + frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) + # other frames by adding reference latents + else: + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float() + video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate + video_positions = video_positions.to(pipe.torch_dtype) + frame_conditions["ref_frames_latents"].append(latents) + frame_conditions["ref_frames_positions"].append(video_positions) + if len(frame_conditions["ref_frames_latents"]) == 0: + frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None}) + return frame_conditions + + +class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"), + output_params=("in_context_video_latents", "in_context_video_positions"), + onload_model_names=("video_vae_encoder") + ) + + def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor): + if in_context_video is None or len(in_context_video) == 0: + raise ValueError("In-context video is None or empty.") + in_context_video = in_context_video[:num_frames] + expected_height = height // in_context_downsample_factor + expected_width = width // in_context_downsample_factor + current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) + h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0) + if current_h != h or current_w != w: + in_context_video = [img.resize((w, h)) for img in in_context_video] + if current_f != f: + # pad black frames at the end + in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f) + return in_context_video + + def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels): + if in_context_videos is None or len(in_context_videos) == 0: + return {} + else: + pipe.load_models_to_device(self.onload_model_names) + latents, positions = [], [] + for in_context_video in in_context_videos: + in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor) + in_context_video = pipe.preprocess_video(in_context_video) + in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate + video_positions[:, 1, ...] 
*= in_context_downsample_factor # height axis + video_positions[:, 2, ...] *= in_context_downsample_factor # width axis + video_positions = video_positions.to(pipe.torch_dtype) + + latents.append(in_context_latents) + positions.append(video_positions) + latents = torch.cat(latents, dim=1) + positions = torch.cat(positions, dim=1) + return {"in_context_video_latents": latents, "in_context_video_positions": positions} + + +class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit): + """ + 1. switch height and width to stage 2 resolution + 2. clear in_context_video_latents and in_context_video_positions + 3. switch stage 2 lora model + """ + def __init__(self): + super().__init__( + input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"), + output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline): + stage2_params = {} + stage2_params.update({"height": stage_2_height, "width": stage_2_width}) + stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None}) + stage2_params.update({"input_latents_video": None, "denoise_mask_video": None}) + if clear_lora_before_state_two: + pipe.clear_lora() + if not use_distilled_pipeline: + pipe.load_lora(pipe.dit, pipe.stage2_lora_config, alpha=pipe.stage2_lora_strength, state_dict=pipe.stage2_lora_config.state_dict) + return stage2_params + + +class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"), + output_params=("video_latents", "audio_latents"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise): + pipe.scheduler.set_timesteps(special_case="stage2") + video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0]) + audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0]) + return {"video_latents": video_latents, "audio_latents": audio_latents} + + +class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("video_latents",), + output_params=("video_latents",), + onload_model_names=("upsampler",), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, video_latents): + if video_latents is None or pipe.upsampler is None: + raise ValueError("No upsampler or no video latents before stage 2.") + else: + pipe.load_models_to_device(self.onload_model_names) + video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents) + video_latents = pipe.upsampler(video_latents) + video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents) + return {"video_latents": video_latents} + + +def model_fn_ltx2( + dit: LTXModel, + video_latents=None, + video_context=None, + video_positions=None, + video_patchifier=None, + audio_latents=None, + audio_context=None, + audio_positions=None, + audio_patchifier=None, + timestep=None, + # First Frame Conditioning + input_latents_video=None, + denoise_mask_video=None, + # Other Frames Conditioning + ref_frames_latents=None, + ref_frames_positions=None, + # In-Context Conditioning + in_context_video_latents=None, + in_context_video_positions=None, + # Audio Inputs + input_latents_audio=None, + denoise_mask_audio=None, + # 
Gradient Checkpointing + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + timestep = timestep.float() / 1000. + + # patchify + b, c_v, f, h, w = video_latents.shape + video_latents = video_patchifier.patchify(video_latents) + seq_len_video = video_latents.shape[1] + video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) + # First frame conditioning by replacing the video latents + if input_latents_video is not None: + denoise_mask_video = video_patchifier.patchify(denoise_mask_video) + video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video) + video_timesteps = denoise_mask_video * video_timesteps + + # Reference conditioning by appending the reference video or frame latents + total_ref_latents = ref_frames_latents if ref_frames_latents is not None else [] + total_ref_positions = ref_frames_positions if ref_frames_positions is not None else [] + total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else [] + total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else [] + if len(total_ref_latents) > 0: + for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions): + ref_frames_latent = video_patchifier.patchify(ref_frames_latent) + ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0. + video_latents = torch.cat([video_latents, ref_frames_latent], dim=1) + video_positions = torch.cat([video_positions, ref_frames_position], dim=2) + video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1) + + if audio_latents is not None: + _, c_a, _, mel_bins = audio_latents.shape + audio_latents = audio_patchifier.patchify(audio_latents) + audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) + else: + audio_timesteps = None + if input_latents_audio is not None: + denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio) + audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio) + audio_timesteps = denoise_mask_audio * audio_timesteps + + vx, ax = dit( + video_latents=video_latents, + video_positions=video_positions, + video_context=video_context, + video_timesteps=video_timesteps, + audio_latents=audio_latents, + audio_positions=audio_positions, + audio_context=audio_context, + audio_timesteps=audio_timesteps, + sigma=timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + vx = vx[:, :seq_len_video, ...] 
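+ # vx now contains only the primary video tokens; reference and in-context tokens appended above have been sliced off.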
+ # unpatchify + vx = video_patchifier.unpatchify_video(vx, f, h, w) + ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None + return vx, ax diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py new file mode 100644 index 0000000000000000000000000000000000000000..d89d3ff8dd701bc62708b0f558154728705549e4 --- /dev/null +++ b/diffsynth/pipelines/mova_audio_video.py @@ -0,0 +1,461 @@ +import sys +import torch, types +from PIL import Image +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.mova_audio_dit import MovaAudioDit +from ..models.mova_audio_vae import DacVAE +from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge +from ..utils.data.audio import convert_to_mono, resample_waveform + + +class MovaAudioVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.text_encoder: WanTextEncoder = None + self.video_dit: WanModel = None # high noise model + self.video_dit2: WanModel = None # low noise model + self.audio_dit: MovaAudioDit = None + self.dual_tower_bridge: DualTowerConditionalBridge = None + self.video_vae: WanVideoVAE = None + self.audio_vae: DacVAE = None + + self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge") + self.in_iteration_models_2 = ("video_dit2", "audio_dit", "dual_tower_bridge") + + self.units = [ + MovaAudioVideoUnit_ShapeChecker(), + MovaAudioVideoUnit_NoiseInitializer(), + MovaAudioVideoUnit_InputVideoEmbedder(), + MovaAudioVideoUnit_InputAudioEmbedder(), + MovaAudioVideoUnit_PromptEmbedder(), + MovaAudioVideoUnit_ImageEmbedderVAE(), + MovaAudioVideoUnit_UnifiedSequenceParallel(), + ] + self.model_fn = model_fn_mova_audio_video + self.compilable_models = ["video_dit", "video_dit2", "audio_dit"] + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward + for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), + use_usp: bool = False, + vram_limit: float = None, + ): + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import 
get_device_name + if dist.is_available() and dist.is_initialized(): + device = get_device_name() + # Initialize pipeline + pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.video_dit, pipe.video_dit2 = dit + else: + pipe.video_dit = dit + pipe.audio_dit = model_pool.fetch_model("mova_audio_dit") + pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge") + pipe.video_vae = model_pool.fetch_model("wan_video_vae") + pipe.audio_vae = model_pool.fetch_model("mova_audio_vae") + set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else [])) + + # Size division factor + if pipe.video_vae is not None: + pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + denoising_strength: Optional[float] = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 352, + width: Optional[int] = 640, + num_frames: Optional[int] = 81, + frame_rate: Optional[int] = 24, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + # Boundary + switch_DiT_boundary: Optional[float] = 0.9, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "denoising_strength": denoising_strength, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate, + "cfg_scale": cfg_scale, + "sigma_shift": sigma_shift, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not 
models["video_dit"] is self.video_dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["video_dit"] = self.video_dit2 + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + # Scheduler + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + + # Decode + self.load_models_to_device(['video_vae']) + video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device(["audio_vae"]) + audio = self.audio_vae.decode(inputs_shared["audio_latents"]) + audio = self.output_audio_format_check(audio) + self.load_models_to_device([]) + return video, audio + + +class MovaAudioVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + +class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"), + output_params=("video_noise", "audio_noise") + ) + + def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate): + length = (num_frames - 1) // 4 + 1 + video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor) + video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device) + + audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1 + audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples) + audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device) + return {"video_noise": video_noise, "audio_noise": audio_noise} + + +class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"), + output_params=("video_latents", "input_latents"), + onload_model_names=("video_vae",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride): + if input_video is None or not pipe.scheduler.training: + return {"video_latents": video_noise} + else: + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"input_latents": input_latents} + + +class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + 
input_params=("input_audio", "audio_noise"), + output_params=("audio_latents", "audio_input_latents"), + onload_model_names=("audio_vae",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise): + if input_audio is None or not pipe.scheduler.training: + return {"audio_latents": audio_noise} + else: + pipe.load_models_to_device(self.onload_model_names) + input_audio, sample_rate = input_audio + input_audio = convert_to_mono(input_audio) + input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate) + input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate) + z, _, _, _, _ = pipe.audio_vae.encode(input_audio) + return {"audio_input_latents": z.mode()} + + +class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt): + ids, mask = pipe.tokenizer( + prompt, + padding="max_length", + max_length=512, + truncation=True, + add_special_tokens=True, + return_mask=True, + return_tensors="pt", + ) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + +class MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("video_vae",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.video_dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + +class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), 
output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: MovaAudioVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {"use_unified_sequence_parallel": False} + + +def model_fn_mova_audio_video( + video_dit: WanModel, + audio_dit: MovaAudioDit, + dual_tower_bridge: DualTowerConditionalBridge, + video_latents: torch.Tensor = None, + audio_latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + y: Optional[torch.Tensor] = None, + frame_rate: Optional[int] = 24, + use_unified_sequence_parallel: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, +): + video_x, audio_x = video_latents, audio_latents + # First-Last Frame + if y is not None: + video_x = torch.cat([video_x, y], dim=1) + + # Timestep + video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep)) + video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim)) + audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep)) + audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim)) + + # Context + video_context = video_dit.text_embedding(context) + audio_context = audio_dit.text_embedding(context) + + # Patchify + video_x = video_dit.patch_embedding(video_x) + f_v, h, w = video_x.shape[2:] + video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous() + seq_len_video = video_x.shape[1] + + audio_x = audio_dit.patch_embedding(audio_x) + f_a = audio_x.shape[2] + audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous() + seq_len_audio = audio_x.shape[1] + + # Freqs + video_freqs = torch.cat([ + video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1), + video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1), + video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1) + ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device) + audio_freqs = torch.cat([ + audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1), + audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1), + audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1), + ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device) + + video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs( + video_fps=frame_rate, + grid_size=(f_v, h, w), + audio_steps=audio_x.shape[1], + device=video_x.device, + dtype=video_x.dtype, + ) + # usp func + if use_unified_sequence_parallel: + from ..utils.xfuser import get_current_chunk, gather_all_chunks + else: + get_current_chunk = lambda x, dim=1: x + gather_all_chunks = lambda x, seq_len, dim=1: x + # Forward blocks + for block_id in range(len(audio_dit.blocks)): + if dual_tower_bridge.should_interact(block_id, "a2v"): + video_x, audio_x = dual_tower_bridge( + block_id, + video_x, + audio_x, + x_freqs=video_rope, + y_freqs=audio_rope, + condition_scale=1.0, + video_grid_size=(f_v, h, w), + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + video_x = get_current_chunk(video_x, dim=1) + video_x = gradient_checkpoint_forward( + video_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video_x, video_context, video_t_mod, video_freqs + ) + video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1) + audio_x = get_current_chunk(audio_x, dim=1) + audio_x 
= gradient_checkpoint_forward( + audio_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + audio_x, audio_context, audio_t_mod, audio_freqs + ) + audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1) + + video_x = get_current_chunk(video_x, dim=1) + for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)): + video_x = gradient_checkpoint_forward( + video_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video_x, video_context, video_t_mod, video_freqs + ) + video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1) + + # Head + video_x = video_dit.head(video_x, video_t) + video_x = video_dit.unpatchify(video_x, (f_v, h, w)) + + audio_x = audio_dit.head(audio_x, audio_t) + audio_x = audio_dit.unpatchify(audio_x, (f_a,)) + return video_x, audio_x diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f3256a19af9f6123bf5aefe3ec684db787e92b7c --- /dev/null +++ b/diffsynth/pipelines/qwen_image.py @@ -0,0 +1,818 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from math import prod + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora + +from ..models.qwen_image_dit import QwenImageDiT +from ..models.qwen_image_text_encoder import QwenImageTextEncoder +from ..models.qwen_image_vae import QwenImageVAE +from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel + + +class QwenImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + from transformers import Qwen2Tokenizer, Qwen2VLProcessor + + self.scheduler = FlowMatchScheduler("Qwen-Image") + self.text_encoder: QwenImageTextEncoder = None + self.dit: QwenImageDiT = None + self.vae: QwenImageVAE = None + self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None + self.tokenizer: Qwen2Tokenizer = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: QwenImageImage2LoRAModel = None + self.image2lora_coarse: QwenImageImage2LoRAModel = None + self.image2lora_fine: QwenImageImage2LoRAModel = None + self.processor: Qwen2VLProcessor = None + self.in_iteration_models = ("dit", "blockwise_controlnet") + self.units = [ + QwenImageUnit_ShapeChecker(), + QwenImageUnit_NoiseInitializer(), + QwenImageUnit_InputImageEmbedder(), + QwenImageUnit_Inpaint(), + QwenImageUnit_EditImageEmbedder(), + QwenImageUnit_LayerInputImageEmbedder(), + QwenImageUnit_ContextImageEmbedder(), + QwenImageUnit_PromptEmbedder(), + QwenImageUnit_EntityControl(), + QwenImageUnit_BlockwiseControlNet(), + ] + self.model_fn = model_fn_qwen_image + self.compilable_models = ["dit"] + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: 
Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + processor_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder") + pipe.dit = model_pool.fetch_model("qwen_image_dit") + pipe.vae = model_pool.fetch_model("qwen_image_vae") + pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all")) + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + from transformers import Qwen2Tokenizer + pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path) + if processor_config is not None: + processor_config.download_if_necessary() + from transformers import Qwen2VLProcessor + pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path) + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style") + pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse") + pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine") + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Inpaint + inpaint_mask: Image.Image = None, + inpaint_blur_size: int = None, + inpaint_blur_sigma: float = None, + # Shape + height: int = 1328, + width: int = 1328, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + exponential_shift_mu: float = None, + # Blockwise ControlNet + blockwise_controlnet_inputs: list[ControlNetInput] = None, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + # Qwen-Image-Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, + edit_rope_interpolation: bool = False, + # Qwen-Image-Edit-2511 + zero_cond_t: bool = False, + # Qwen-Image-Layered + layer_input_image: Image.Image = None, + layer_num: int = None, + # In-context control + context_image: Image.Image = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + 
"blockwise_controlnet_inputs": blockwise_controlnet_inputs, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, + "context_image": context_image, + "zero_cond_t": zero_cond_t, + "layer_input_image": layer_input_image, + "layer_num": layer_num, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if layer_num is None: + image = self.vae_output_to_image(image) + else: + image = [self.vae_output_to_image(i, pattern="C H W") for i in image] + self.load_models_to_device([]) + + return image + + +class QwenImageBlockwiseMultiControlNet(torch.nn.Module): + def __init__(self, models: list[QwenImageBlockWiseControlNet]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + for model in models: + if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"): + self.vram_management_enabled = True + + def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs): + processed_conditionings = [] + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning) + processed_conditionings.append(model_output) + return processed_conditionings + + def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs): + res = 0 + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4): + continue + model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id) + res = res + model_output * controlnet_input.scale + return res + + +class QwenImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: QwenImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class 
QwenImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device", "layer_num"), + output_params=("noise",), + ) + + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num): + if layer_num is None: + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + else: + noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + + +class QwenImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + if isinstance(input_image, list): + input_latents = [] + for image in input_image: + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)) + input_latents = torch.concat(input_latents, dim=0) + else: + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"), + output_params=("layer_input_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride): + if layer_input_image is None: + return {} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"layer_input_latents": latents} + + +class QwenImageUnit_Inpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"), + output_params=("inpaint_mask",), + ) + + def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma): + if inpaint_mask is None: + return {} + inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1) + inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True) + if inpaint_blur_size is not None and inpaint_blur_sigma is not None: + from torchvision.transforms import GaussianBlur + blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma) + inpaint_mask = blur(inpaint_mask) + return {"inpaint_mask": inpaint_mask} + + +class QwenImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": 
"prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def calculate_dimensions(self, target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + def resize_image(self, image, target_area=384*384): + width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1]) + return image.resize((width, height)) + + def encode_prompt(self, pipe: QwenImagePipeline, prompt): + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + if model_inputs.input_ids.shape[1] >= 1024: + print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))]) + txt = [template.format(base_img_prompt + e) for e in prompt] + edit_image = [self.resize_image(image) for image in edit_image] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: + pipe.load_models_to_device(self.onload_model_names) + if pipe.text_encoder is not None: + prompt = [prompt] + if edit_image is None: + split_hidden_states = self.encode_prompt(pipe, prompt) + elif isinstance(edit_image, Image.Image): + split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image) + else: + split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict: + if pipe.text_encoder is not None: + prompt = [prompt] + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] + + split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + 
split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + prompt_embs, prompt_emb_masks = [], [] + for entity_prompt in entity_prompts: + prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt) + prompt_embs.append(prompt_emb_dict['prompt_emb']) + prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask']) + return prompt_embs, prompt_emb_masks, entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi) + entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + 
return inputs_shared, inputs_posi, inputs_nega + + + +class QwenImageUnit_BlockwiseControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("blockwise_controlnet_conditioning",), + onload_model_names=("vae",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if blockwise_controlnet_inputs is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + conditionings = [] + for controlnet_input in blockwise_controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + + return {"blockwise_controlnet_conditioning": conditionings} + + +class QwenImageUnit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"), + output_params=("edit_latents", "edit_image"), + onload_model_names=("vae",) + ) + + + def calculate_dimensions(self, target_area, ratio): + import math + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + + def edit_image_auto_resize(self, edit_image): + calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) + return edit_image.resize((calculated_width, calculated_height)) + + + def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(edit_image, Image.Image): + resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image + edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype) + edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + else: + resized_edit_image, edit_latents = [], [] + for image in edit_image: + if edit_image_auto_resize: + image = self.edit_image_auto_resize(image) + resized_edit_image.append(image) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + edit_latents.append(latents) + return {"edit_latents": edit_latents, "edit_image": 
resized_edit_image} + + +class QwenImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8) + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + prompt = [prompt] + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return prompt_embeds.view(1, -1) + + def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False): + pipe.load_models_to_device(["text_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) if highres else self.processor_lowres(image) + embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = 
self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + residual = None + residual_highres = None + if pipe.image2lora_coarse is not None: + residual = self.encode_images_using_qwenvl(pipe, images, highres=False) + if pipe.image2lora_fine is not None: + residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True) + return x, residual, residual_highres + + def process(self, pipe: QwenImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x, residual, residual_highres = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres} + + +class QwenImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + output_params=("lora",), + onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"), + ) + + def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + if pipe.image2lora_coarse is not None: + pipe.load_models_to_device(["image2lora_coarse"]) + for x, residual in zip(image2lora_x, image2lora_residual): + loras.append(pipe.image2lora_coarse(x=x, residual=residual)) + if pipe.image2lora_fine is not None: + pipe.load_models_to_device(["image2lora_fine"]) + for x, residual in zip(image2lora_x, image2lora_residual_highres): + loras.append(pipe.image2lora_fine(x=x, residual=residual)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +class QwenImageUnit_ContextImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride", "layer_input_image"), + output_params=("context_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride, layer_input_image=None): + if context_image is None: + return {} + if layer_input_image is not None: + context_image = context_image.convert("RGBA") + pipe.load_models_to_device(self.onload_model_names) + context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype) + context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"context_latents": context_latents} + + +def model_fn_qwen_image( + dit: QwenImageDiT = None, + blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + blockwise_controlnet_conditioning=None, + blockwise_controlnet_inputs=None, + progress_id=0, + num_inference_steps=1, + entity_prompt_emb=None, + entity_prompt_emb_mask=None, + entity_masks=None, + edit_latents=None, + layer_input_latents=None, + layer_num=None, + context_latents=None, + enable_fp8_attention=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + edit_rope_interpolation=False, + zero_cond_t=False, + **kwargs +): + if layer_num is None: + layer_num = 1 + img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] + 
else: + layer_num = layer_num + 1 + img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + timestep = timestep / 1000 + + image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num) + image_seq_len = image.shape[1] + + if context_latents is not None: + img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)] + context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2) + image = torch.cat([image, context_image], dim=1) + if edit_latents is not None: + edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents] + img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list] + edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list] + image = torch.cat([image] + edit_image, dim=1) + if layer_input_latents is not None: + layer_num = layer_num + 1 + img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)] + layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + image = torch.cat([image, layer_input_latents], dim=1) + + image = dit.img_in(image) + if zero_cond_t: + timestep = torch.cat([timestep, timestep * 0], dim=0) + modulate_index = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]], + device=timestep.device, + dtype=torch.int, + ) + else: + modulate_index = None + conditioning = dit.time_text_embed( + timestep, + image.dtype, + addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long) + ) + + if entity_prompt_emb is not None: + text, image_rotary_emb, attention_mask = dit.process_entity_masks( + latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, + entity_masks, height, width, image, img_shapes, + ) + else: + text = dit.txt_in(dit.txt_norm(prompt_emb)) + if edit_rope_interpolation: + image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device) + else: + image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + attention_mask = None + + if blockwise_controlnet_conditioning is not None: + blockwise_controlnet_conditioning = blockwise_controlnet.preprocess( + blockwise_controlnet_inputs, blockwise_controlnet_conditioning) + + for block_id, block in enumerate(dit.transformer_blocks): + text, image = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + modulate_index=modulate_index, + ) + if blockwise_controlnet_conditioning is not None: + image_slice = image[:, :image_seq_len].clone() + controlnet_output = blockwise_controlnet.blockwise_forward( + image=image_slice, conditionings=blockwise_controlnet_conditioning, + controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id, + progress_id=progress_id, num_inference_steps=num_inference_steps, + ) + image[:, :image_seq_len] = image_slice + controlnet_output + + if zero_cond_t: + conditioning = 
conditioning.chunk(2, dim=0)[0] + image = dit.norm_out(image, conditioning) + image = dit.proj_out(image) + image = image[:, :image_seq_len] + + latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1) + return latents diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..f175929ede8c51b9b3ecce60d101c8c4898ac072 --- /dev/null +++ b/diffsynth/pipelines/wan_video.py @@ -0,0 +1,1799 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import Wav2Vec2Processor + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel, tokenize_target_text +from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel +from ..models.wav2vec import WanS2VAudioEncoder +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.audio_processor: Wav2Vec2Processor = None + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + 
WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + WanVideoUnit_WanToDance_ProcessInputs(), + WanVideoUnit_WanToDance_RefImageEmbedder(), + WanVideoUnit_WanToDance_ImageKeyframesEmbedder(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + self.compilable_models = ["dit", "dit2"] + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + if self.vace is not None: + for block in self.vace.vace_blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.vace.forward = types.MethodType(usp_vace_forward, self.vace) + if self.vace2 is not None: + for block in self.vace2.vace_blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import get_device_name + if dist.is_available() and dist.is_initialized(): + device = get_device_name() + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit + pipe.vae = model_pool.fetch_model("wan_video_vae") + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + glyph_video: Optional[list[Image.Image]] = None, + target_text: Optional[str] = 
None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # WanToDance + wantodance_music_path: Optional[str] = None, + wantodance_reference_image: Optional[Image.Image] = None, + wantodance_fps: Optional[float] = 30, + wantodance_keyframes: Optional[list[Image.Image]] = None, + wantodance_keyframes_mask: Optional[list[int]] = None, + framewise_decoding: bool = False, + # progress_bar + progress_bar_cmd=tqdm, + output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized", + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "glyph_video": glyph_video, "target_text": target_text, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + 
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + "wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps, + "wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask, + "framewise_decoding": framewise_decoding, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # Pixel-Anchored Denoising: anchor non-mask regions to original video latents + if hasattr(self, "_anchor_latents") and self._anchor_latents is not None: + anchor = self._anchor_latents + mask_lat = self._anchor_mask_latent + # Compute what the original latents look like at the next timestep + if progress_id + 1 < len(self.scheduler.timesteps): + next_t = self.scheduler.timesteps[progress_id + 1] + next_sigma = self.scheduler.sigmas[torch.argmin((self.scheduler.timesteps - next_t.cpu()).abs())] + anchor_noisy = (1 - next_sigma) * anchor + next_sigma * self._anchor_noise + else: + anchor_noisy = anchor # final step: use clean original + # Blend: mask=1 means text region (keep generated), mask=0 means background (use original) + inputs_shared["latents"] = mask_lat * inputs_shared["latents"] + (1 - mask_lat) * anchor_noisy + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + if framewise_decoding: + video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device) + else: + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, 
tile_stride=tile_stride) + if output_type == "quantized": + video = self.vae_output_to_video(video) + elif output_type == "floatpoint": + pass + self.load_models_to_device([]) + return video + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + if framewise_decoding: + input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device) + else: + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = 
mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
+ """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + reference_image = reference_image.resize((width, height)) + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, 
camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "glyph_video", "target_text", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "glyph_latent", "target_text_ids", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, glyph_video, target_text, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None or glyph_video is not None or target_text is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is 
None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + # TextVACE: encode glyph video separately (NOT concat, used via cross-attention) + glyph_latent = None + if glyph_video is not None: + glyph_video_tensor = pipe.preprocess_video(glyph_video) + glyph_latent = pipe.vae.encode( + glyph_video_tensor, device=pipe.device, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ).to(dtype=pipe.torch_dtype, device=pipe.device) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + + # Tokenize target text for TargetTextEncoder (v3 mode) + target_text_ids = None + if target_text is not None and isinstance(target_text, str): + vace_model = getattr(pipe, 'vace', None) + vocab_size = 8192 + max_len = 64 + if vace_model is not None and hasattr(vace_model, 'target_text_encoder'): + vocab_size = vace_model.target_text_encoder.vocab_size + max_len = vace_model.target_text_encoder.max_len + ids = tokenize_target_text(target_text, max_len=max_len, vocab_size=vocab_size) + target_text_ids = torch.tensor([ids], dtype=torch.long, device=pipe.device) + + return {"vace_context": vace_context, "glyph_latent": glyph_latent, "target_text_ids": target_text_ids, "vace_scale": vace_scale} + else: + return {"vace_context": None, "glyph_latent": None, "target_text_ids": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + 
take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("vap_clip_feature", "vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. 
prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return 
inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None: + assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}" + motion_latents = motion_video + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + pipe.load_models_to_device(["vae"]) + infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), 
inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." + shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames) + height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"] + unit = WanVideoUnit_S2V() + audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True) + pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + pose_latents = None if s2v_pose_video is None else pose_latents + return audio_embeds, pose_latents, len(audio_embeds) + + +class WanVideoPostUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + + def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): + if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: + return {} + latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2) + return {"latents": latents} + + +class WanVideoUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), + output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") + ) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, 
"animate_mask_video": animate_mask_video} + + +class WanVideoUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + output_params=("pose_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("animate_face_video",), + output_params=("face_pixel_values"), + ) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return inputs_shared, inputs_posi, inputs_nega + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, 
mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + ) + + def get_music_base_feature(self, music_path, fps=30): + import librosa + hop_length = 512 + sr = fps * hop_length + data, sr = librosa.load(music_path, sr=sr) + sr = 22050 + envelope = librosa.onset.onset_strength(y=data, sr=sr) + mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T + chroma = librosa.feature.chroma_cens( + y=data, sr=sr, hop_length=hop_length, n_chroma=12 + ).T + peak_idxs = librosa.onset.onset_detect( + onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length + ) + peak_onehot = np.zeros_like(envelope, dtype=np.float32) + peak_onehot[peak_idxs] = 1.0 + start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0] + _, beat_idxs = librosa.beat.beat_track( + onset_envelope=envelope, + sr=sr, + hop_length=hop_length, + start_bpm=start_bpm, + tightness=100, + ) + beat_onehot = np.zeros_like(envelope, dtype=np.float32) + beat_onehot[beat_idxs] = 1.0 + audio_feature = np.concatenate( + [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]], + axis=-1, + ) + return torch.from_numpy(audio_feature) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.dit.wantodance_enable_global: + inputs_nega["skip_9th_layer"] = True + if inputs_shared.get("wantodance_music_path", None) is not None: + inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device) + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("wantodance_refimage_feature",), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride): + if wantodance_reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(wantodance_reference_image, list): + wantodance_reference_image = wantodance_reference_image[0] + image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1 + refimage_feature = pipe.image_encoder.encode_image([image]) + refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"wantodance_refimage_feature": refimage_feature} + + +class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + 
input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("clip_feature", "y"), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride): + if wantodance_keyframes is None: + return {} + wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask) + pipe.load_models_to_device(self.onload_model_names) + images = [] + for input_image in wantodance_keyframes: + input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + images.append(input_image) + + clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入 + msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1 + + images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w + images = torch.concat(images, dim=1) + vae_input = images + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3 + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context, "y": y} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. 
Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + +def wantodance_get_single_freqs(freqs, frame_num, fps): + total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5) + interval_frame = 30.0 / (fps + 1e-6) + freqs_0 = freqs[:total_frame] + freqs_new = 
torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype) + freqs_new[0] = freqs_0[0] + freqs_new[-1] = freqs_0[total_frame - 1] + for i in range(1, frame_num-1): + pos = i * interval_frame + low_idx = int(pos) + high_idx = min(low_idx + 1, total_frame - 1) + weight_high = pos - low_idx + weight_low = 1.0 - weight_high + freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high + return freqs_new + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + glyph_latent = None, + target_text_ids = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + wantodance_refimage_feature = None, + wantodance_fps: float = 30.0, + music_feature = None, + skip_9th_layer: bool = False, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # wan2.2 s2v + if audio_embeds is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + 
get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30: + x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True) + else: + x = dit.patchify(x, control_camera_latents_input) + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + + # Patchify + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + # Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) + + # TeaCache + if tea_cache is not None: + tea_cache_update 
= tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + # WanToDance + if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global: + if wantodance_refimage_feature is not None: + refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature) + context = torch.cat([refimage_feature_embedding, context], dim=1) + if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30: + freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps) + freqs = torch.cat([ + freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel: + if use_unified_sequence_parallel: + length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size() + music_feature = music_feature[:length] + music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] + if not dit.training: + dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation + music_feature = music_feature.to(x.device, dtype=x.dtype) + music_feature = dit.music_projection(music_feature) + music_feature = dit.music_encoder(music_feature) + if music_feature.dim() == 2: + music_feature = music_feature.unsqueeze(0) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + music_feature = get_sp_group().all_gather(music_feature, dim=1) + music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800] + N = 149 + M = 4800 + music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear') + music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800] + if music_feature is not None: + if music_feature.dim() == 2: + music_feature = music_feature.unsqueeze(0) + music_feature = music_feature.to(x.device, dtype=x.dtype) + interp_mode = 'bilinear' + if interp_mode == 'bilinear': + frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21 + context_shape_end = context.shape[2] ## 14B 5120 + music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800] + if use_unified_sequence_parallel: + N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size() + else: + N = frame_num * 8 + music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear') + music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end] + if use_unified_sequence_parallel: + dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + else: + dit.merged_audio_emb = music_feature + else: + dit.merged_audio_emb = music_feature + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + glyph_latent=glyph_latent, + target_text_ids=target_text_ids, + 
use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + # Offload hints to CPU to save GPU memory during DiT blocks (same as PISCO line 1530) + if use_gradient_checkpointing_offload: + vace_hints = [h.cpu() for h in vace_hints] + torch.cuda.empty_cache() + if tea_cache_update: + x = tea_cache.update(x) + else: + from diffsynth.models.wan_video_vace import _OffloadToCPU, _RestoreToGPU + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_offload(module, gpu_device): + """Checkpoint wrapper that restores CPU inputs to GPU before running the block.""" + def custom_forward(x_cpu, ctx_cpu, tmod_arg, freqs_cpu): + x_gpu = _RestoreToGPU.apply(x_cpu, gpu_device) + ctx_gpu = _RestoreToGPU.apply(ctx_cpu, gpu_device) + freqs_gpu = _RestoreToGPU.apply(freqs_cpu, gpu_device) + return module(x_gpu, ctx_gpu, tmod_arg, freqs_gpu) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + # Pre-offload shared tensors to CPU (same as PISCO lines 1567-1570) + if use_gradient_checkpointing_offload: + _gpu_dev = x.device + context_cpu = _OffloadToCPU.apply(context).requires_grad_(True) + freqs_cpu = _OffloadToCPU.apply(freqs).requires_grad_(True) + + # Block + for block_id, block in enumerate(dit.blocks): + if skip_9th_layer: + # This is only used in WanToDance + if block_id == 9: + continue + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=True, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=True, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + if use_gradient_checkpointing_offload: + x_cpu = _OffloadToCPU.apply(x).requires_grad_(True) + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward_offload(block, _gpu_dev), + x_cpu, context_cpu, t_mod, freqs_cpu, + use_reentrant=True, + ) + del x_cpu + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=True, + ) + else: + x = block(x, context, t_mod, freqs) + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if not isinstance(current_vace_hint, torch.Tensor): + pass + elif current_vace_hint.device != x.device: + current_vace_hint = current_vace_hint.to(x.device) + x = x + current_vace_hint * vace_scale + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + + # WanToDance + if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject: + x = dit.wantodance_after_transformer_block(block_id, x) + if tea_cache is not None: + tea_cache.store(x) + + if hasattr(dit, "wantodance_enable_unimodel") and 
dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30: + x = dit.head_global(x, t) + else: + x = dit.head(x, t) + + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return output + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_embeds, + motion_latents, + s2v_pose_latents, + drop_motion_frames=True, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, +): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) + + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel + + # reference image + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) + + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, 
f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, seq_len_x, pre_compute_freqs[0] + ) + x = gradient_checkpoint_forward( + lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x), + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x + ) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] + x = dit.head(x, t[:-1]) + x = dit.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..59e44b375727b365131e4a405a7bcee2a940f6dd --- /dev/null +++ b/diffsynth/pipelines/z_image.py @@ -0,0 +1,689 @@ +import torch, math, warnings +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple, Iterable, Dict + +from ..core.device.npu_compatible_device import get_device_type, IS_NPU_AVAILABLE +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..core.data.operators import ImageCropAndResize +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora import merge_lora + +from transformers import AutoTokenizer +from ..models.z_image_text_encoder import ZImageTextEncoder +from ..models.z_image_dit import ZImageDiT +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M +from ..models.z_image_controlnet import ZImageControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.z_image_image2lora import ZImageImage2LoRAModel + + +class ZImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Z-Image") + self.text_encoder: ZImageTextEncoder = None + self.dit: ZImageDiT = None + self.vae_encoder: FluxVAEEncoder = None + self.vae_decoder: FluxVAEDecoder = None + self.image_encoder: Siglip2ImageEncoder428M = None + self.controlnet: ZImageControlNet = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: ZImageImage2LoRAModel = None + self.tokenizer: AutoTokenizer = None + self.in_iteration_models = ("dit", "controlnet") + self.units = [ + ZImageUnit_ShapeChecker(), + ZImageUnit_PromptEmbedder(), + 
ZImageUnit_NoiseInitializer(), + ZImageUnit_InputImageEmbedder(), + ZImageUnit_EditImageAutoResize(), + ZImageUnit_EditImageEmbedderVAE(), + ZImageUnit_EditImageEmbedderSiglip(), + ZImageUnit_PAIControlNet(), + ] + self.model_fn = model_fn_z_image + self.compilable_models = ["dit"] + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + enable_npu_patch: bool = True, + ): + # Initialize pipeline + pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("z_image_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") + pipe.controlnet = model_pool.fetch_model("z_image_controlnet") + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + # NPU patch + apply_npu_patch(enable_npu_patch) + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 8, + sigma_shift: float = None, + # ControlNet + controlnet_inputs: List[ControlNetInput] = None, + # Image to LoRA + image2lora_images: List[Image.Image] = None, + positive_only_lora: Dict[str, torch.Tensor] = None, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "controlnet_inputs": controlnet_inputs, + "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = 
timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class ZImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: ZImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class ZImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params=("edit_image",), + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = pipe.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def encode_prompt_omni( + self, + pipe, + prompt: Union[str, List[str]], + edit_image=None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + if edit_image is None: + num_condition_images = 0 + elif isinstance(edit_image, list): + num_condition_images = len(edit_image) + else: + num_condition_images = 1 + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = pipe.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + 
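+        # The flattened chat-formatted segments are encoded in one batch below;
+        # prompt_list_lengths is then used to regroup the unpadded token
+        # embeddings so that each original prompt gets its own list of segments.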
text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def process(self, pipe: ZImagePipeline, prompt, edit_image): + pipe.load_models_to_device(self.onload_model_names) + if hasattr(pipe, "dit") and pipe.dit is not None and pipe.dit.siglip_embedder is not None: + # Z-Image-Turbo and Z-Image-Omni-Base use different prompt encoding methods. + # We determine which encoding method to use based on the model architecture. + # If you are using two-stage split training, + # please use `--offload_models` instead of skipping the DiT model loading. + prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device) + else: + prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) + return {"prompt_embeds": prompt_embeds} + + +class ZImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: ZImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class ZImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae_encoder(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class ZImageUnit_EditImageAutoResize(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "edit_image_auto_resize"), + output_params=("edit_image",), + ) + + def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize): + if edit_image is None: + return {} + if edit_image_auto_resize is None or not edit_image_auto_resize: + return {} + operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16) + if not isinstance(edit_image, list): + edit_image = [edit_image] + edit_image = [operator(i) for i in edit_image] + return {"edit_image": edit_image} + + +class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_embeds",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + 
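+        # Each edit image is encoded separately so that every condition image
+        # keeps its own SigLIP feature sequence when passed to the DiT.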
image_emb = [] + for image_ in edit_image: + image_emb.append(pipe.image_encoder(image_, device=pipe.device)) + return {"image_embeds": image_emb} + + +class ZImageUnit_EditImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image",), + output_params=("image_latents",), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, edit_image): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if not isinstance(edit_image, list): + edit_image = [edit_image] + image_latents = [] + for image_ in edit_image: + image_ = pipe.preprocess_image(image_) + image_latents.append(pipe.vae_encoder(image_)) + return {"image_latents": image_latents} + + +class ZImageUnit_PAIControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "height", "width"), + output_params=("control_context", "control_scale"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width): + if controlnet_inputs is None: + return {} + if len(controlnet_inputs) != 1: + print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.") + controlnet_input = controlnet_inputs[0] + pipe.load_models_to_device(self.onload_model_names) + + control_image = controlnet_input.image + if control_image is not None: + control_image = pipe.preprocess_image(control_image) + control_latents = pipe.vae_encoder(control_image) + else: + control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1 + + inpaint_mask = controlnet_input.inpaint_mask + if inpaint_mask is not None: + inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1) + inpaint_image = controlnet_input.inpaint_image + inpaint_image = pipe.preprocess_image(inpaint_image) + inpaint_image = inpaint_image * (inpaint_mask < 0.5) + inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1] + else: + inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_latent = pipe.vae_encoder(inpaint_image) + + control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1) + control_context = rearrange(control_context, "B C H W -> B C 1 H W") + return {"control_context": control_context, "control_scale": controlnet_input.scale} + + +def model_fn_z_image( + dit: ZImageDiT, + controlnet: ZImageControlNet = None, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, + image_latents=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + # Due to the complex and verbose codebase of Z-Image, + # we are temporarily using this inelegant structure. + # We will refactor this part in the future (if time permits). 
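+    # Dispatch: checkpoints without a SigLIP embedder (e.g. Z-Image-Turbo) are
+    # handled by model_fn_z_image_turbo below. Otherwise the Omni path packs the
+    # noisy latents (optionally preceded by the edit-image latents) and the
+    # SigLIP features into the nested-list format expected by the DiT, and maps
+    # the scheduler timestep to (1000 - t) / 1000 before the forward pass.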
+ if dit.siglip_embedder is None: + return model_fn_z_image_turbo( + dit, + controlnet=controlnet, + latents=latents, + timestep=timestep, + prompt_embeds=prompt_embeds, + image_embeds=image_embeds, + image_latents=image_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + **kwargs, + ) + latents = [rearrange(latents, "B C H W -> C B H W")] + if dit.siglip_embedder is not None: + if image_latents is not None: + image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents] + latents = [image_latents + latents] + image_noise_mask = [[0] * len(image_latents) + [1]] + else: + latents = [latents] + image_noise_mask = [[1]] + image_embeds = [image_embeds] + else: + image_noise_mask = None + timestep = (1000 - timestep) / 1000 + model_output = dit( + latents, + timestep, + prompt_embeds, + siglip_feats=image_embeds, + image_noise_mask=image_noise_mask, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + )[0] + model_output = -model_output + model_output = rearrange(model_output, "C B H W -> B C H W") + return model_output + + +class ZImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x",), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + return x + + def process(self, pipe: ZImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x} + + +class ZImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x",), + output_params=("lora",), + onload_model_names=("image2lora_style",), + ) + + def process(self, pipe: ZImagePipeline, image2lora_x): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +def model_fn_z_image_turbo( + dit: ZImageDiT, + controlnet: ZImageControlNet = None, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, 
+ image_latents=None, + control_context=None, + control_scale=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + while isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] + while isinstance(latents, list): + latents = latents[0] + while isinstance(image_embeds, list): + image_embeds = image_embeds[0] + + # Timestep + timestep = 1000 - timestep + t_noisy = dit.t_embedder(timestep) + t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000) + + # Patchify + latents = rearrange(latents, "B C H W -> C B H W") + x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds]) + x = x[0] + cap_feats = cap_feats[0] + + # Noise refine + x = dit.all_x_embedder["2-1"](x) + x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device) + x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0)) + x = rearrange(x, "L C -> 1 L C") + x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C") + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy) + refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner( + dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.noise_refiner): + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, + attn_mask=None, + freqs_cis=x_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + x = x + refiner_hints[layer_id] * control_scale + + # Prompt refine + cap_feats = dit.cap_embedder(cap_feats) + cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device) + cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0)) + cap_feats = rearrange(cap_feats, "L C -> 1 L C") + cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C") + + for layer in dit.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=None, + freqs_cis=cap_freqs_cis, + ) + + # Unified + unified = torch.cat([x, cap_feats], dim=1) + unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1) + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy) + hints = controlnet.forward_layers( + unified, cap_feats, control_context, control_context_item_seqlens, kwargs, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.layers): + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, + attn_mask=None, + freqs_cis=unified_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + if layer_id in controlnet.control_layers_mapping: + unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale + + # Output + unified = 
dit.all_final_layer["2-1"](unified, t_noisy) + x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0] + x = rearrange(x, "C B H W -> B C H W") + x = -x + return x + + +def apply_npu_patch(enable_npu_patch: bool=True): + if IS_NPU_AVAILABLE and enable_npu_patch: + from ..models.general_modules import RMSNorm + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + from ..models.z_image_dit import Attention + from ..core.npu_patch.npu_fused_operator import ( + rms_norm_forward_npu, + rms_norm_forward_transformers_npu, + rotary_emb_Zimage_npu + ) + warnings.warn("Replacing RMSNorm and Rope with NPU fusion operators to improve the performance of the model on NPU.Set enable_npu_patch=False to disable this feature.") + RMSNorm.forward = rms_norm_forward_npu + Qwen3RMSNorm.forward = rms_norm_forward_transformers_npu + Attention.apply_rotary_emb = rotary_emb_Zimage_npu diff --git a/diffsynth/utils/controlnet/__init__.py b/diffsynth/utils/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df23b6c61b99319f1f41d6448b0c52ffd03b9f25 --- /dev/null +++ b/diffsynth/utils/controlnet/__init__.py @@ -0,0 +1,2 @@ +from .controlnet_input import ControlNetInput +from .annotator import Annotator diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..cb737385f75bf1edd681c24c3118b0ac0d79e185 --- /dev/null +++ b/diffsynth/utils/controlnet/annotator.py @@ -0,0 +1,63 @@ +from typing_extensions import Literal, TypeAlias + +from diffsynth.core.device.npu_compatible_device import get_device_type + +Processor_id: TypeAlias = Literal[ + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" +] + +class Annotator: + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False): + if not skip_processor: + if processor_id == "canny": + from controlnet_aux.processor import CannyDetector + self.processor = CannyDetector() + elif processor_id == "depth": + from controlnet_aux.processor import MidasDetector + self.processor = MidasDetector.from_pretrained(model_path).to(device) + elif processor_id == "softedge": + from controlnet_aux.processor import HEDdetector + self.processor = HEDdetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart": + from controlnet_aux.processor import LineartDetector + self.processor = LineartDetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart_anime": + from controlnet_aux.processor import LineartAnimeDetector + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) + elif processor_id == "openpose": + from controlnet_aux.processor import OpenposeDetector + self.processor = OpenposeDetector.from_pretrained(model_path).to(device) + elif processor_id == "normal": + from controlnet_aux.processor import NormalBaeDetector + self.processor = NormalBaeDetector.from_pretrained(model_path).to(device) + elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") + else: + self.processor = None + + self.processor_id = processor_id + self.detect_resolution = detect_resolution + + def to(self,device): + if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"): + + self.processor.model.to(device) + + 
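+    # Illustrative usage (editor's note, not part of the original patch; the
+    # input is assumed to be a PIL.Image):
+    #     annotator = Annotator("canny")
+    #     edge_map = annotator(pil_image)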
def __call__(self, image, mask=None): + width, height = image.size + if self.processor_id == "openpose": + kwargs = { + "include_body": True, + "include_hand": True, + "include_face": True + } + else: + kwargs = {} + if self.processor is not None: + detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) + image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) + image = image.resize((width, height)) + return image + diff --git a/diffsynth/utils/controlnet/controlnet_input.py b/diffsynth/utils/controlnet/controlnet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..a79064bb51fa3625ca692a183544ec9720ca33b9 --- /dev/null +++ b/diffsynth/utils/controlnet/controlnet_input.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from PIL import Image + + +@dataclass +class ControlNetInput: + controlnet_id: int = 0 + scale: float = 1.0 + start: float = 1.0 + end: float = 0.0 + image: Image.Image = None + inpaint_image: Image.Image = None + inpaint_mask: Image.Image = None + processor_id: str = None diff --git a/diffsynth/utils/data/__init__.py b/diffsynth/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..edc3d413bf7d78d23cc232f5a60ce5eabc966456 --- /dev/null +++ b/diffsynth/utils/data/__init__.py @@ -0,0 +1,217 @@ +import imageio, os +import numpy as np +from PIL import Image +from tqdm import tqdm +import subprocess +import shutil + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return Image.open(self.file_list[item]).convert("RGB") + + def __del__(self): + pass + + +def crop_and_resize(image, height, width): + image = np.array(image) + image_height, image_width, _ = image.shape + if image_height / image_width < height / width: + croped_width = int(image_height / height * width) + left = (image_width - croped_width) // 2 + image = image[:, left: left+croped_width] + image = Image.fromarray(image).resize((width, height)) + else: + croped_height = int(image_width / width * height) + left = (image_height - croped_height) // 2 + image = image[left: left+croped_height, :] + image = Image.fromarray(image).resize((width, height)) + return image + + 
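+# Editor's note: a minimal illustrative sketch of crop_and_resize (sizes are
+# hypothetical). The image is center-cropped to the target aspect ratio and
+# then resized:
+#
+#     frame = Image.new("RGB", (1920, 1080))
+#     crop_and_resize(frame, height=512, width=512).size   # -> (512, 512)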
+class VideoData: + def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.set_shape(height, width) + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + width, height = self.__getitem__(0).size + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + width, height = frame.size + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = crop_and_resize(frame, self.height, self.width) + return frame + + def __del__(self): + pass + + def save_images(self, folder): + os.makedirs(folder, exist_ok=True) + for i in tqdm(range(self.__len__()), desc="Saving images"): + frame = self.__getitem__(i) + frame.save(os.path.join(folder, f"{i}.png")) + + +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + +def save_frames(frames, save_path): + os.makedirs(save_path, exist_ok=True) + for i, frame in enumerate(tqdm(frames, desc="Saving images")): + frame.save(os.path.join(save_path, f"{i}.png")) + + +def merge_video_audio(video_path: str, audio_path: str): + # TODO: may need a in-python implementation to avoid subprocess dependency + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. 
+ + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + + +def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None): + save_video(frames, save_path, fps, quality, ffmpeg_params) + merge_video_audio(save_path, audio_path) diff --git a/diffsynth/utils/data/audio.py b/diffsynth/utils/data/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..414fcb2e4d5dab8563db6d538365c666b6944ea4 --- /dev/null +++ b/diffsynth/utils/data/audio.py @@ -0,0 +1,108 @@ +import torch +import torchaudio + + +def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor: + """ + Convert audio to mono by averaging channels. + Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T]. + """ + return audio_tensor.mean(dim=-2, keepdim=True) + + +def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor: + """ + Convert audio to stereo. + Supports [C, T] or [B, C, T]. Duplicate mono, keep stereo. + """ + if audio_tensor.size(-2) == 1: + return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1) + return audio_tensor + + +def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(dtype=waveform.dtype) + + +def read_audio_with_torchcodec( + path: str, + start_time: float = 0, + duration: float | None = None, +) -> tuple[torch.Tensor, int]: + """ + Read audio from file natively using torchcodec, with optional start time and duration. + + Args: + path (str): The file path to the audio file. + start_time (float, optional): The start time in seconds to read from. Defaults to 0. + duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate. + The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames. 
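+
+    Example (illustrative; the file path is hypothetical):
+        waveform, sample_rate = read_audio_with_torchcodec("speech.wav", start_time=1.0, duration=2.0)
+        # waveform has shape [C, T] with T approximately 2.0 * sample_rate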
+ """ + from torchcodec.decoders import AudioDecoder + decoder = AudioDecoder(path) + stop_seconds = None if duration is None else start_time + duration + waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data + return waveform, decoder.metadata.sample_rate + + +def read_audio( + path: str, + start_time: float = 0, + duration: float | None = None, + resample: bool = False, + resample_rate: int = 48000, + backend: str = "torchcodec", +) -> tuple[torch.Tensor, int]: + """ + Read audio from file, with optional start time, duration, and resampling. + + Args: + path (str): The file path to the audio file. + start_time (float, optional): The start time in seconds to read from. Defaults to 0. + duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None. + resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False. + resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000. + backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec". + + Returns: + tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate. + The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames. + """ + if backend == "torchcodec": + waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration) + else: + raise ValueError(f"Unsupported audio backend: {backend}") + + if resample: + waveform = resample_waveform(waveform, sample_rate, resample_rate) + sample_rate = resample_rate + + return waveform, sample_rate + + +def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"): + """ + Save audio tensor to file. + + Args: + waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T]. + sample_rate (int): The sample rate of the audio. + save_path (str): The file path to save the audio to. + backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec". 
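+
+    Example (illustrative; paths are hypothetical):
+        waveform, sample_rate = read_audio("input.wav", resample=True, resample_rate=48000)
+        save_audio(waveform, sample_rate, "output.wav")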
+ """ + if waveform.dim() == 3: + waveform = waveform[0] + + if backend == "torchcodec": + from torchcodec.encoders import AudioEncoder + encoder = AudioEncoder(waveform, sample_rate=sample_rate) + encoder.to_file(dest=save_path) + else: + raise ValueError(f"Unsupported audio backend: {backend}") diff --git a/diffsynth/utils/data/audio_video.py b/diffsynth/utils/data/audio_video.py new file mode 100644 index 0000000000000000000000000000000000000000..6914b2d1246d4299e4ebd87deff9dd1dd68dd08c --- /dev/null +++ b/diffsynth/utils/data/audio_video.py @@ -0,0 +1,134 @@ +import av +from fractions import Fraction +import torch +from PIL import Image +from tqdm import tqdm +from .audio import convert_to_stereo + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int +) -> None: + if samples.ndim == 1: + samples = samples.unsqueeze(0) + samples = convert_to_stereo(samples) + assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2" + samples = samples.T + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac") + supported_sample_rates = audio_stream.codec_context.codec.audio_rates + if supported_sample_rates: + best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate)) + if best_rate != audio_sample_rate: + print(f"Using closest supported audio sample rate: {best_rate}") + else: + best_rate = audio_sample_rate + audio_stream.codec_context.sample_rate = best_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, best_rate) + return audio_stream + + +def write_video_audio( + video: list[Image.Image], + audio: torch.Tensor | None, + output_path: str, + fps: int = 24, + audio_sample_rate: int | None = None, +) -> None: + """ + Writes a sequence of images and an audio tensor to a video file. + + This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream + and multiplex a PyTorch tensor as the audio stream into the output container. 
+ + Args: + video (list[Image.Image]): A list of PIL Image objects representing the video frames. + The length of this list determines the total duration of the video based on the FPS. + audio (torch.Tensor | None): The audio data as a PyTorch tensor. + The shape is typically (channels, samples). If no audio is required, pass None. + channels can be 1 or 2. 1 for mono, 2 for stereo. + output_path (str): The file path (including extension) where the output video will be saved. + fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24. + audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio. + If the audio tensor is provided and this is None, the function attempts to infer the rate + based on the audio tensor's length and the video duration. + Raises: + ValueError: If an audio tensor is provided but the sample rate cannot be determined. + """ + duration = len(video) / fps + if audio_sample_rate is None: + audio_sample_rate = int(audio.shape[-1] / duration) + + width, height = video[0].size + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame in tqdm(video, total=len(video)): + frame = av.VideoFrame.from_image(frame) + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/diffsynth/utils/data/media_io_ltx2.py b/diffsynth/utils/data/media_io_ltx2.py new file mode 100644 index 0000000000000000000000000000000000000000..425278651e34e622c7917796037fafb96ca40771 --- /dev/null +++ b/diffsynth/utils/data/media_io_ltx2.py @@ -0,0 +1,43 @@ +import av +import numpy as np +from io import BytesIO +from .audio_video import write_video_audio as write_video_audio_ltx2 + + +def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None: + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}) + # Round to nearest multiple of 2 for compatibility with video codecs + height = image_array.shape[0] // 2 * 2 + width = image_array.shape[1] // 2 * 2 + image_array = image_array[:height, :width] + stream.height = height + stream.width = width + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p") + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def decode_single_frame(video_file: str) -> np.array: + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array: + if crf == 0: + return image + + with BytesIO() as output_file: + encode_single_frame(output_file, image, crf) + video_bytes = output_file.getvalue() + with BytesIO(video_bytes) as video_file: + image_array = decode_single_frame(video_file) + return image_array diff --git 
a/diffsynth/utils/lora/__init__.py b/diffsynth/utils/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb5901acba99ed8490079b8ebaeb6991ae3f59d --- /dev/null +++ b/diffsynth/utils/lora/__init__.py @@ -0,0 +1,3 @@ +from .general import GeneralLoRALoader +from .merge import merge_lora +from .reset_rank import reset_lora_rank \ No newline at end of file diff --git a/diffsynth/utils/lora/flux.py b/diffsynth/utils/lora/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..97599b652a8cb97a004f8d3264d1ac4716612260 --- /dev/null +++ b/diffsynth/utils/lora/flux.py @@ -0,0 +1,302 @@ +from .general import GeneralLoRALoader +import torch, math + + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight", + 
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight", + 
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight", + } + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().fuse_lora_to_base_model(model, state_dict_lora, alpha) + + def convert_state_dict(self, state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." 
in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + + mlp = mlp.to(device=state_dict_[name].device) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + + +class FluxLoRAConverter: + def __init__(self): + pass + + @staticmethod + def align_to_opensource_format(state_dict, alpha=None): + prefix_rename_dict = { + "single_blocks": "lora_unet_single_blocks", + "blocks": "lora_unet_double_blocks", + } + middle_rename_dict = { + "norm.linear": "modulation_lin", + "to_qkv_mlp": "linear1", + "proj_out": "linear2", + + "norm1_a.linear": "img_mod_lin", + "norm1_b.linear": "txt_mod_lin", + "attn.a_to_qkv": "img_attn_qkv", + "attn.b_to_qkv": "txt_attn_qkv", + "attn.a_to_out": "img_attn_proj", + "attn.b_to_out": "txt_attn_proj", + "ff_a.0": "img_mlp_0", + "ff_a.2": "img_mlp_2", + "ff_b.0": "txt_mlp_0", + "ff_b.2": "txt_mlp_2", + } + suffix_rename_dict = { + "lora_B.weight": 
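The qkv fusion above builds the fused `lora_A` by stacking the per-projection A factors and the fused `lora_B` as a block-diagonal matrix. A tiny verification sketch (illustrative only, not part of the patch, toy shapes) shows why this reproduces the per-projection updates of the row-concatenated qkv weight; the mlp factor handled in the single-block branch follows the same pattern.

# Illustrative check: stacked A plus block-diagonal B equals stacked per-projection deltas.
import torch

d, r, in_dim = 8, 2, 6  # per-projection output dim, LoRA rank, input dim (toy values)
Aq, Ak, Av = (torch.randn(r, in_dim) for _ in range(3))
Bq, Bk, Bv = (torch.randn(d, r) for _ in range(3))

A_fused = torch.cat([Aq, Ak, Av], dim=0)       # (3r, in_dim)
B_fused = torch.zeros(3 * d, 3 * r)            # block-diagonal, as in convert_state_dict
B_fused[0*d:1*d, 0*r:1*r] = Bq
B_fused[1*d:2*d, 1*r:2*r] = Bk
B_fused[2*d:3*d, 2*r:3*r] = Bv

delta_fused = B_fused @ A_fused                           # update of the fused qkv weight
delta_stacked = torch.cat([Bq @ Aq, Bk @ Ak, Bv @ Av], dim=0)
assert torch.allclose(delta_fused, delta_stacked, atol=1e-5)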
"lora_up.weight", + "lora_A.weight": "lora_down.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + names = name.split(".") + if names[-2] != "lora_A" and names[-2] != "lora_B": + names.pop(-2) + prefix = names[0] + middle = ".".join(names[2:-2]) + suffix = ".".join(names[-2:]) + block_id = names[1] + if middle not in middle_rename_dict: + continue + rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix] + state_dict_[rename] = param + if rename.endswith("lora_up.weight"): + lora_alpha = alpha if alpha is not None else param.shape[-1] + state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0] + return state_dict_ + + @staticmethod + def align_to_diffsynth_format(state_dict): + rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + 
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + def guess_block_id(name): + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + return None, None + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name) + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + return state_dict_ diff --git a/diffsynth/utils/lora/general.py b/diffsynth/utils/lora/general.py new file mode 100644 index 0000000000000000000000000000000000000000..85ada77598b563c597bdd9d638076eeadc896669 --- /dev/null +++ b/diffsynth/utils/lora/general.py @@ -0,0 +1,70 @@ +import torch, warnings + + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_up." in key: + lora_A_key = "lora_down" + lora_B_key = "lora_up" + else: + lora_A_key = "lora_A" + lora_B_key = "lora_B" + if lora_B_key not in key: + continue + keys = key.split(".") + if len(keys) > keys.index(lora_B_key) + 2: + keys.pop(keys.index(lora_B_key) + 1) + keys.pop(keys.index(lora_B_key)) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + # Alpha: Deprecated but retained for compatibility. + key_alpha = key.replace(lora_B_key + ".weight", "alpha").replace(lora_B_key + ".default.weight", "alpha") + if key_alpha == key or key_alpha not in lora_state_dict: + key_alpha = None + lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key), key_alpha) + return lora_name_dict + + + def convert_state_dict(self, state_dict, suffix=".weight"): + name_dict = self.get_name_dict(state_dict) + state_dict_ = {} + for name in name_dict: + weight_up = state_dict[name_dict[name][0]] + weight_down = state_dict[name_dict[name][1]] + if name_dict[name][2] is not None: + warnings.warn("Alpha detected in the LoRA file. This may be a LoRA model not trained by DiffSynth-Studio. 
To ensure compatibility, the LoRA weights will be converted to weight * alpha / rank.") + alpha = state_dict[name_dict[name][2]] / weight_down.shape[0] + weight_down = weight_down * alpha + state_dict_[name + f".lora_B{suffix}"] = weight_up + state_dict_[name + f".lora_A{suffix}"] = weight_down + return state_dict_ + + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict, alpha=1.0): + updated_num = 0 + state_dict = self.convert_state_dict(state_dict) + lora_layer_names = set([i.replace(".lora_B.weight", "") for i in state_dict if i.endswith(".lora_B.weight")]) + for name, module in model.named_modules(): + if name in lora_layer_names: + weight_up = state_dict[name + ".lora_B.weight"].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict[name + ".lora_A.weight"].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict_base = module.state_dict() + state_dict_base["weight"] = state_dict_base["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict_base) + updated_num += 1 + print(f"{updated_num} tensors are fused by LoRA. Fused LoRA layers cannot be cleared by `pipe.clear_lora()`.") diff --git a/diffsynth/utils/lora/merge.py b/diffsynth/utils/lora/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..61904ff4bcebc6c344c23f26073aec292355217c --- /dev/null +++ b/diffsynth/utils/lora/merge.py @@ -0,0 +1,20 @@ +import torch +from typing import Dict, List + + +def merge_lora_weight(tensors_A, tensors_B): + lora_A = torch.concat(tensors_A, dim=0) + lora_B = torch.concat(tensors_B, dim=1) + return lora_A, lora_B + + +def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1): + lora_merged = {} + keys = [i for i in loras[0].keys() if ".lora_A." in i] + for key in keys: + tensors_A = [lora[key] for lora in loras] + tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras] + lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B) + lora_merged[key] = lora_A * alpha + lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B + return lora_merged diff --git a/diffsynth/utils/lora/reset_rank.py b/diffsynth/utils/lora/reset_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..9522b043ff962bc050fa79596197f00abf3877b0 --- /dev/null +++ b/diffsynth/utils/lora/reset_rank.py @@ -0,0 +1,20 @@ +import torch + +def decomposite(tensor_A, tensor_B, rank): + dtype, device = tensor_A.dtype, tensor_A.device + weight = tensor_B @ tensor_A + U, S, V = torch.pca_lowrank(weight.float(), q=rank) + tensor_A = (V.T).to(dtype=dtype, device=device).contiguous() + tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous() + return tensor_A, tensor_B + +def reset_lora_rank(lora, rank): + lora_merged = {} + keys = [i for i in lora.keys() if ".lora_A." 
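A minimal sketch of `GeneralLoRALoader.fuse_lora_to_base_model` (illustrative only, not part of the patch): the model and LoRA key names are hypothetical, chosen only so the module path "proj" matches the LoRA key prefix. If the file also carried a "proj.alpha" entry, the loader would rescale `lora_A` by alpha / rank as described in the warning above.

# Illustrative sketch: fuse a toy rank-4 LoRA into a single Linear layer.
import torch
from diffsynth.utils.lora import GeneralLoRALoader

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16, bias=False)

model = ToyModel()
rank = 4
lora_state_dict = {
    "proj.lora_A.weight": torch.randn(rank, 16) * 0.01,
    "proj.lora_B.weight": torch.randn(16, rank) * 0.01,
}

loader = GeneralLoRALoader(device="cpu", torch_dtype=torch.float32)
loader.fuse_lora_to_base_model(model, lora_state_dict, alpha=1.0)  # prints "1 tensors are fused by LoRA. ..."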
in i] + for key in keys: + tensor_A = lora[key] + tensor_B = lora[key.replace(".lora_A.", ".lora_B.")] + tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank) + lora_merged[key] = tensor_A + lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B + return lora_merged \ No newline at end of file diff --git a/diffsynth/utils/ses/README.md b/diffsynth/utils/ses/README.md new file mode 100644 index 0000000000000000000000000000000000000000..67ffa00840b3d729bb09c6924f10b9ad4153ed30 --- /dev/null +++ b/diffsynth/utils/ses/README.md @@ -0,0 +1 @@ +Please see `docs/en/Research_Tutorial/inference_time_scaling.md` or `docs/zh/Research_Tutorial/inference_time_scaling.md` for more details. diff --git a/diffsynth/utils/ses/__init__.py b/diffsynth/utils/ses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9906105a78c67f85993f19977cc77dac9c4ba4 --- /dev/null +++ b/diffsynth/utils/ses/__init__.py @@ -0,0 +1 @@ +from .ses import ses_search \ No newline at end of file diff --git a/diffsynth/utils/ses/ses.py b/diffsynth/utils/ses/ses.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c0b7a520f3bd64f1fe9c45bc06f134699c4856 --- /dev/null +++ b/diffsynth/utils/ses/ses.py @@ -0,0 +1,117 @@ +import torch +import pywt +import numpy as np +from tqdm import tqdm + + +def split_dwt(z_tensor_cpu, wavelet_name, dwt_level): + all_clow_np = [] + all_chigh_list = [] + z_tensor_cpu = z_tensor_cpu.cpu().float() + + for i in range(z_tensor_cpu.shape[0]): + z_numpy_ch = z_tensor_cpu[i].numpy() + + coeffs_ch = pywt.wavedec2(z_numpy_ch, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1)) + + clow_np = coeffs_ch[0] + chigh_list = coeffs_ch[1:] + + all_clow_np.append(clow_np) + all_chigh_list.append(chigh_list) + + all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0)) + return all_clow_tensor, all_chigh_list + + +def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape): + H_high, W_high = original_shape + c_low_tensor_cpu = c_low_tensor_cpu.cpu().float() + + clow_np = c_low_tensor_cpu.numpy() + + if clow_np.ndim == 4 and clow_np.shape[0] == 1: + clow_np = clow_np[0] + + coeffs_combined = [clow_np] + c_high_coeffs + z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1)) + if z_recon_np.shape[-2] != H_high or z_recon_np.shape[-1] != W_high: + z_recon_np = z_recon_np[..., :H_high, :W_high] + z_recon_tensor = torch.from_numpy(z_recon_np) + if z_recon_tensor.ndim == 3: + z_recon_tensor = z_recon_tensor.unsqueeze(0) + return z_recon_tensor + + +def ses_search( + base_latents, + objective_reward_fn, + total_eval_budget=30, + popsize=10, + k_elites=5, + wavelet_name="db1", + dwt_level=4, +): + latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1] + c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level) + c_high_fixed = c_high_fixed_batch[0] + c_low_shape = c_low_init.shape[1:] + mu = torch.zeros_like(c_low_init.view(-1).cpu()) + sigma_sq = torch.ones_like(mu) * 1.0 + + best_overall = {"fitness": -float('inf'), "score": -float('inf'), "c_low": c_low_init[0]} + eval_count = 0 + + elite_db = [] + n_generations = (total_eval_budget // popsize) + 5 + pbar = tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img") + + for gen in range(n_generations): + if eval_count >= total_eval_budget: break + + std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9)) + z_noise = torch.randn(popsize, mu.shape[0]) + samples_flat = mu + z_noise * std + samples_reshaped 
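A small sketch of `reset_lora_rank` (illustrative only, not part of the patch): the key names are hypothetical, and the function only requires matching ".lora_A." / ".lora_B." pairs; it returns factors re-decomposed at the requested rank.

# Illustrative sketch: shrink a toy LoRA from rank 8 to rank 4.
import torch
from diffsynth.utils.lora import reset_lora_rank

lora = {
    "proj.lora_A.weight": torch.randn(8, 32),
    "proj.lora_B.weight": torch.randn(64, 8),
}
lora_rank4 = reset_lora_rank(lora, rank=4)
print(lora_rank4["proj.lora_A.weight"].shape)  # torch.Size([4, 32])
print(lora_rank4["proj.lora_B.weight"].shape)  # torch.Size([64, 4])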
= samples_flat.view(popsize, *c_low_shape) + + batch_results = [] + + for i in range(popsize): + if eval_count >= total_eval_budget: break + + c_low_sample = samples_reshaped[i].unsqueeze(0) + z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w)) + z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype) + # img = pipeline_callback(z_recon) + + # score = scorer.get_score(img, prompt) + score = objective_reward_fn(z_recon) + res = { + "score": score, + "c_low": c_low_sample.cpu() + } + batch_results.append(res) + if score > best_overall['score']: + best_overall = res + + eval_count += 1 + pbar.update(1) + + if not batch_results: break + elite_db.extend(batch_results) + elite_db.sort(key=lambda x: x['score'], reverse=True) + elite_db = elite_db[:k_elites] + elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db]) + mu_new = torch.mean(elites_flat, dim=0) + + if len(elite_db) > 1: + sigma_sq_new = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7 + else: + sigma_sq_new = sigma_sq + mu = mu_new + sigma_sq = sigma_sq_new + pbar.close() + best_c_low = best_overall['c_low'] + final_latents = reconstruct_dwt(best_c_low, c_high_fixed, wavelet_name, (latent_h, latent_w)) + + return final_latents.to(base_latents.device, dtype=base_latents.dtype) diff --git a/diffsynth/utils/state_dict_converters/__init__.py b/diffsynth/utils/state_dict_converters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffsynth/utils/state_dict_converters/anima_dit.py b/diffsynth/utils/state_dict_converters/anima_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..16afc768370eb10cca06e899d9b509af5ed5ad8d --- /dev/null +++ b/diffsynth/utils/state_dict_converters/anima_dit.py @@ -0,0 +1,6 @@ +def AnimaDiTStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + value = state_dict[key] + new_state_dict[key.replace("net.", "")] = value + return new_state_dict diff --git a/diffsynth/utils/state_dict_converters/flux2_text_encoder.py b/diffsynth/utils/state_dict_converters/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0975e62a35021c697192ad054f0e3aff42289292 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux2_text_encoder.py @@ -0,0 +1,17 @@ +def Flux2TextEncoderStateDictConverter(state_dict): + rename_dict = { + "multi_modal_projector.linear_1.weight": "model.multi_modal_projector.linear_1.weight", + "multi_modal_projector.linear_2.weight": "model.multi_modal_projector.linear_2.weight", + "multi_modal_projector.norm.weight": "model.multi_modal_projector.norm.weight", + "multi_modal_projector.patch_merger.merging_layer.weight": "model.multi_modal_projector.patch_merger.merging_layer.weight", + "language_model.lm_head.weight": "lm_head.weight", + } + state_dict_ = {} + for k in state_dict: + k_ = k + k_ = k_.replace("language_model.model", "model.language_model") + k_ = k_.replace("vision_tower", "model.vision_tower") + if k_ in rename_dict: + k_ = rename_dict[k_] + state_dict_[k_] = state_dict[k] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_controlnet.py b/diffsynth/utils/state_dict_converters/flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..15f9447d22bc0ebc2dbb3d2eac8dbf0bd78e4151 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_controlnet.py @@ -0,0 +1,103 @@ +import torch + + +def 
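A minimal way to call `ses_search` (illustrative only, not part of the patch): the latent shape and the reward function are placeholders; in practice the reward would decode the candidate latent and score the resulting image, as the commented-out pipeline/scorer calls above suggest. The module relies on pywavelets (imported as pywt).

# Illustrative sketch: run ses_search on a dummy latent with a toy reward
# (negative L2 norm), so the search drifts toward lower-energy low-frequency
# DWT coefficients while the high-frequency bands stay fixed.
import torch
from diffsynth.utils.ses import ses_search

base_latents = torch.randn(1, 16, 64, 64)  # toy latent, H = W = 64

def objective_reward_fn(latents: torch.Tensor) -> float:
    return -latents.norm().item()

best_latents = ses_search(
    base_latents,
    objective_reward_fn,
    total_eval_budget=12,
    popsize=4,
    k_elites=2,
)
print(best_latents.shape)  # torch.Size([1, 16, 64, 64])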
FluxControlNetStateDictConverter(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_dit.py b/diffsynth/utils/state_dict_converters/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..f808b60defd800ff97ec78fec1dac6f472038cb7 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_dit.py @@ -0,0 +1,197 @@ +import torch + + +def FluxDiTStateDictConverter(state_dict): + is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0 + if is_nexus_gen: + dit_state_dict = {} + for key in state_dict: + if key.startswith('pipe.dit.'): + param = state_dict[key] + new_key = key.replace("pipe.dit.", "") + if new_key.startswith("final_norm_out.linear."): + param = torch.concat([param[3072:], param[:3072]], dim=0) + dit_state_dict[new_key] = param + return dit_state_dict + + rename_dict = { + "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", + "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", + "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias", + "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight", + "txt_in.bias": "context_embedder.bias", + "txt_in.weight": "context_embedder.weight", + "vector_in.in_layer.bias": "pooled_text_embedder.0.bias", + "vector_in.in_layer.weight": "pooled_text_embedder.0.weight", + "vector_in.out_layer.bias": "pooled_text_embedder.2.bias", + "vector_in.out_layer.weight": "pooled_text_embedder.2.weight", + "final_layer.linear.bias": "final_proj_out.bias", + "final_layer.linear.weight": "final_proj_out.weight", + "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias", + "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight", + "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias", + "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight", + "img_in.bias": "x_embedder.bias", + "img_in.weight": "x_embedder.weight", + "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias", + } + suffix_rename_dict = { + "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight", + "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight", + "img_attn.proj.bias": "attn.a_to_out.bias", + "img_attn.proj.weight": "attn.a_to_out.weight", + "img_attn.qkv.bias": "attn.a_to_qkv.bias", + "img_attn.qkv.weight": "attn.a_to_qkv.weight", + "img_mlp.0.bias": "ff_a.0.bias", + "img_mlp.0.weight": "ff_a.0.weight", + "img_mlp.2.bias": "ff_a.2.bias", + "img_mlp.2.weight": "ff_a.2.weight", + "img_mod.lin.bias": "norm1_a.linear.bias", + "img_mod.lin.weight": "norm1_a.linear.weight", + "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight", + "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight", + "txt_attn.proj.bias": "attn.b_to_out.bias", + "txt_attn.proj.weight": "attn.b_to_out.weight", + "txt_attn.qkv.bias": 
"attn.b_to_qkv.bias", + "txt_attn.qkv.weight": "attn.b_to_qkv.weight", + "txt_mlp.0.bias": "ff_b.0.bias", + "txt_mlp.0.weight": "ff_b.0.weight", + "txt_mlp.2.bias": "ff_b.2.bias", + "txt_mlp.2.weight": "ff_b.2.weight", + "txt_mod.lin.bias": "norm1_b.linear.bias", + "txt_mod.lin.weight": "norm1_b.linear.weight", + + "linear1.bias": "to_qkv_mlp.bias", + "linear1.weight": "to_qkv_mlp.weight", + "linear2.bias": "proj_out.bias", + "linear2.weight": "proj_out.weight", + "modulation.lin.bias": "norm.linear.bias", + "modulation.lin.weight": "norm.linear.weight", + "norm.key_norm.scale": "norm_k_a.weight", + "norm.query_norm.scale": "norm_q_a.weight", + } + state_dict_ = {} + for name in state_dict: + original_name = name + if name.startswith("model.diffusion_model."): + name = name[len("model.diffusion_model."):] + names = name.split(".") + if name in rename_dict: + rename = rename_dict[name] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "double_blocks": + rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "single_blocks": + if ".".join(names[2:]) in suffix_rename_dict: + rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + else: + pass + return state_dict_ + + +def FluxDiTStateDictConverterFromDiffusers(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + if global_rename_dict[prefix] == "final_norm_out.linear": + param = torch.concat([param[3072:], param[:3072]], dim=0) + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + 
middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + pass + else: + pass + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + mlp = torch.zeros(4 * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_infiniteyou.py b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..7025b392d54c5b4844ed3b3387bd010217897f4a --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py @@ -0,0 +1,2 @@ +def FluxInfiniteYouImageProjectorStateDictConverter(state_dict): + return state_dict['image_proj'] \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_ipadapter.py b/diffsynth/utils/state_dict_converters/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..86dfb133655fbe9c33c84b419706a103cec96b1b --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_ipadapter.py @@ -0,0 +1,32 @@ +def FluxIpAdapterStateDictConverter(state_dict): + state_dict_ = {} + + if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict): + for name, param in state_dict["ip_adapter"].items(): + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = param + + if "image_proj" in state_dict: + for name, param in state_dict["image_proj"].items(): + name_ = "image_proj." 
+ name + state_dict_[name_] = param + return state_dict_ + + for key, value in state_dict.items(): + if key.startswith("image_proj."): + state_dict_[key] = value + elif key.startswith("ip_adapter."): + new_key = key.replace("ip_adapter.", "ipadapter_modules.") + state_dict_[new_key] = value + else: + pass + + return state_dict_ + + +def SiglipStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if key.startswith("vision_model."): + new_state_dict[key] = state_dict[key] + return new_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..aa018aa5c570cc67f4856002e8f1f83f18998e07 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py @@ -0,0 +1,31 @@ +def FluxTextEncoderClipStateDictConverter(state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias", + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..d35eb831d2a7b1d48eee747d251d6cfb6ad508ef --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py @@ -0,0 +1,4 @@ +def FluxTextEncoderT5StateDictConverter(state_dict): + state_dict_ = {i: state_dict[i] for i in state_dict} + state_dict_["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_vae.py b/diffsynth/utils/state_dict_converters/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6547f18f1e1cfe69d0cf4ef43860702812d25fab --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_vae.py @@ -0,0 +1,382 @@ +def FluxVAEEncoderStateDictConverter(state_dict): + rename_dict = { + "encoder.conv_in.bias": "conv_in.bias", + "encoder.conv_in.weight": "conv_in.weight", + "encoder.conv_out.bias": "conv_out.bias", + "encoder.conv_out.weight": "conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", + 
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", + "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", + "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", + "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", + "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", + "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", + "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", + "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", + "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", + "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", + "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", + "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", + "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", + 
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", + "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", + "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", + "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", + "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", + "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", + "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", + "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", + "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", + "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", + "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", + "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", + "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", + "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", + "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", + "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", + "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", + "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", + "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", + "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", + "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", + "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", + "encoder.norm_out.bias": "conv_norm_out.bias", + "encoder.norm_out.weight": "conv_norm_out.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEDecoderStateDictConverter(state_dict): + rename_dict = { + "decoder.conv_in.bias": "conv_in.bias", + "decoder.conv_in.weight": "conv_in.weight", + "decoder.conv_out.bias": "conv_out.bias", + "decoder.conv_out.weight": "conv_out.weight", + "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias", + "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight", + "decoder.mid.attn_1.proj_out.bias": 
"blocks.1.transformer_blocks.0.to_out.bias", + "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight", + "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias", + "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight", + "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias", + "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight", + "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias", + "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight", + "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias", + "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight", + "decoder.norm_out.bias": "conv_norm_out.bias", + "decoder.norm_out.weight": "conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias", + "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias", + "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight", + "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias", + 
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias", + "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight", + "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight", + "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias", + "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight", + "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight", + "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight", + "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias", + "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight", + "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias", + "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight", + "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias", + "decoder.up.3.block.0.norm2.weight": 
"blocks.3.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight", + "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight", + "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight", + "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias", + "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEEncoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "quant_conv": "quant_conv", + "encoder.conv_in": "conv_in", + "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm", + "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q", + "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k", + "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v", + "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out", + "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1", + "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1", + "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2", + "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2", + "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1", + "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1", + "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2", + "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2", + "encoder.conv_norm_out": "conv_norm_out", + "encoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("encoder.down_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ + + +def FluxVAEDecoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "post_quant_conv": "post_quant_conv", + "decoder.conv_in": "conv_in", + "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm", + "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q", + "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k", + "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v", + "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out", + "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1", + "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1", + "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2", + "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2", + "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1", + "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1", + "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2", + "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2", + "decoder.conv_norm_out": "conv_norm_out", + "decoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("decoder.up_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..02185309b263840a36ea21b7d9b064740ccc4ff3 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py @@ -0,0 +1,32 @@ +def LTX2AudioEncoderStateDictConverter(state_dict): + # Not used + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.encoder."): + new_name = name.replace("audio_vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2AudioDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.decoder."): + new_name = name.replace("audio_vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VocoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vocoder."): + new_name = name[len("vocoder."):] + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_dit.py b/diffsynth/utils/state_dict_converters/ltx2_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..baffb9a6d571e38f831e1643ef86f165bbe9f8a1 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_dit.py @@ -0,0 +1,9 @@ +def LTXModelStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e528f7ca1aa576b12f9e07411a82a183e7d18c --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py @@ -0,0 +1,31 @@ +def LTX2TextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("language_model.model."): + new_key = key.replace("language_model.model.", 
"model.language_model.") + elif key.startswith("vision_tower."): + new_key = key.replace("vision_tower.", "model.vision_tower.") + elif key.startswith("multi_modal_projector."): + new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.") + elif key.startswith("language_model.lm_head."): + new_key = key.replace("language_model.lm_head.", "lm_head.") + else: + continue + state_dict_[new_key] = state_dict[key] + state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight") + return state_dict_ + + +def LTX2TextEncoderPostModulesStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("text_embedding_projection."): + new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.") + elif key.startswith("model.diffusion_model.video_embeddings_connector."): + new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.") + elif key.startswith("model.diffusion_model.audio_embeddings_connector."): + new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.") + else: + continue + state_dict_[new_key] = state_dict[key] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_video_vae.py b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..53df15e544fcb51e0bf91ff4f6e8bb93b6e1841a --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py @@ -0,0 +1,24 @@ +def LTX2VideoEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.encoder."): + new_name = name.replace("vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]: + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VideoDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.decoder."): + new_name = name.replace("vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + if new_name not in ["per_channel_statistics.channel", "per_channel_statistics.mean-of-stds", "per_channel_statistics.mean-of-stds_over_std-of-means"]: + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/nexus_gen.py b/diffsynth/utils/state_dict_converters/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..aff853d0e76dd1f130ce44241462f61af84370db --- /dev/null +++ b/diffsynth/utils/state_dict_converters/nexus_gen.py @@ -0,0 +1,6 @@ +def NexusGenAutoregressiveModelStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + value = state_dict[key] + new_state_dict["model." 
+ key] = value + return new_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/nexus_gen_projector.py b/diffsynth/utils/state_dict_converters/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a44665551ba4a97d063de94f6025c8c48989fd --- /dev/null +++ b/diffsynth/utils/state_dict_converters/nexus_gen_projector.py @@ -0,0 +1,15 @@ +def NexusGenMergerStateDictConverter(state_dict): + merger_state_dict = {} + for key in state_dict: + if key.startswith('embedding_merger.'): + value = state_dict[key] + new_key = key.replace("embedding_merger.", "") + merger_state_dict[new_key] = value + return merger_state_dict + +def NexusGenAdapterStateDictConverter(state_dict): + adapter_state_dict = {} + for key in state_dict: + if key.startswith('adapter.'): + adapter_state_dict[key] = state_dict[key] + return adapter_state_dict \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py b/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e8192a1f2a959685cf1fa5af40824bd896454141 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py @@ -0,0 +1,10 @@ +def QwenImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for k in state_dict: + v = state_dict[k] + if k.startswith("visual."): + k = "model." + k + elif k.startswith("model."): + k = k.replace("model.", "model.language_model.") + state_dict_[k] = v + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/step1x_connector.py b/diffsynth/utils/state_dict_converters/step1x_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..35a2a4167b16ea5cc16aaa1b0f20575bc2918bbf --- /dev/null +++ b/diffsynth/utils/state_dict_converters/step1x_connector.py @@ -0,0 +1,7 @@ +def Qwen2ConnectorStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("connector."): + name_ = name[len("connector."):] + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py b/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea69f4e6696bbef6de197abaa031ea8cc5b398e --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py @@ -0,0 +1,6 @@ +def WanAnimateAdapterStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"): + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_dit.py b/diffsynth/utils/state_dict_converters/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c7716dad52e42ebf76f98dd85511ac0a04b3d3b3 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_dit.py @@ -0,0 +1,83 @@ +def WanVideoDiTFromDiffusers(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + 
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + 
name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = state_dict[name] + return state_dict_ + + +def WanVideoDiTStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vace"): + continue + if name.split(".")[0] in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]: + continue + name_ = name + if name_.startswith("model."): + name_ = name_[len("model."):] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py b/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb7e9bfce50e88601f8876341ac56645a8e5913 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py @@ -0,0 +1,8 @@ +def WanImageEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("textual."): + continue + name_ = "model." + name + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wan_video_mot.py b/diffsynth/utils/state_dict_converters/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..12b42d7db752fca1cb24c0f16217deab925916f5 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_mot.py @@ -0,0 +1,78 @@ +def WanVideoMotStateDictConverter(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + 
"blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36) + mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)} + state_dict_ = {} + for name in state_dict: + if "_mot_ref" not in name: + continue + param = state_dict[name] + name = name.replace("_mot_ref", "") + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + if name.split(".")[1].isdigit(): + block_id = int(name.split(".")[1]) + name = name.replace(str(block_id), str(mot_layers_mapping[block_id])) + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/wan_video_vace.py b/diffsynth/utils/state_dict_converters/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1aec13e7f2be66d5faf8389547a96aa14020fd --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_vace.py @@ -0,0 +1,40 @@ +import torch + + +def VaceWanModelDictConverter(state_dict): + state_dict_ = {name: state_dict[name] for name in state_dict if name.startswith("vace")} + return state_dict_ + + +def expand_patch_embedding_channels(model, state_dict, glyph_channels): + """Expand vace_patch_embedding Conv3D input channels to accommodate glyph channels. + + Pretrained weights cover the original 96 input channels (inactive + reactive + mask). + New glyph channels (16) are zero-initialized so the model starts from pretrained + behavior and gradually learns to use glyph information during fine-tuning. 
+ """ + if glyph_channels <= 0: + return + + key_w = "vace_patch_embedding.weight" + key_b = "vace_patch_embedding.bias" + + if key_w not in state_dict: + return + + pretrained_w = state_dict[key_w] # (out_ch, 96, 1, 2, 2) + out_ch = pretrained_w.shape[0] + orig_in = pretrained_w.shape[1] + kernel = pretrained_w.shape[2:] + + expected_in = orig_in + glyph_channels + if model.vace_patch_embedding.weight.shape[1] != expected_in: + return + + new_w = torch.zeros(out_ch, expected_in, *kernel, dtype=pretrained_w.dtype) + new_w[:, :orig_in] = pretrained_w + state_dict[key_w] = new_w + + # Bias is per output channel, no change needed + if key_b in state_dict: + pass # bias shape is (out_ch,), unchanged diff --git a/diffsynth/utils/state_dict_converters/wan_video_vae.py b/diffsynth/utils/state_dict_converters/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..76a430e1bd4575e0ae06234de23b620d4877566f --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wan_video_vae.py @@ -0,0 +1,7 @@ +def WanVideoVAEStateDictConverter(state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py b/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa12c0d4ff7fc166eea1f804cb645be9aa28776 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py @@ -0,0 +1,12 @@ +def WanS2VAudioEncoderStateDictConverter(state_dict): + rename_dict = { + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_g": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0", + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_v": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1", + } + state_dict_ = {} + for name in state_dict: + name_ = "model." 
+ name + if name_ in rename_dict: + name_ = rename_dict[name_] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/z_image_dit.py b/diffsynth/utils/state_dict_converters/z_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..0f44d8bbc067236b2cb74abe10496ff2ebf468b0 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/z_image_dit.py @@ -0,0 +1,3 @@ +def ZImageDiTStateDictConverter(state_dict): + state_dict_ = {name.replace("model.diffusion_model.", ""): state_dict[name] for name in state_dict} + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/z_image_text_encoder.py b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b11461345e808edd2c4ca793419ae137b70bfbc9 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/z_image_text_encoder.py @@ -0,0 +1,6 @@ +def ZImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name != "lm_head.weight": + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/xfuser/__init__.py b/diffsynth/utils/xfuser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5faca2fa5de1ee0baccfc89dbdd33315ef6905 --- /dev/null +++ b/diffsynth/utils/xfuser/__init__.py @@ -0,0 +1 @@ +from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..abf0f3fef630356c80e456256c7e346aea82216f --- /dev/null +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -0,0 +1,206 @@ +import torch +from typing import Optional +from einops import rearrange +from yunchang.kernels import AttnType +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ... 
import IS_NPU_AVAILABLE +from ...core.device import parse_nccl_backend, parse_device_type +from ...core.gradient import gradient_checkpoint_forward + + +def initialize_usp(device_type): + import torch.distributed as dist + from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment + dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), + ) + getattr(torch, device_type).set_device(dist.get_rank()) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + original_tensor_device = original_tensor.device + if original_tensor.device == "npu": + original_tensor = original_tensor.cpu() + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device) + return padded_tensor + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def usp_dit_forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # Context Parallel + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + + for block in self.blocks: + if self.training: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, 
context, t_mod, freqs + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + + # unpatchify + x = self.unpatchify(x, (f, h, w)) + return x + + +def usp_vace_forward( + self, x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, +): + # Compute full sequence length from the sharded x + full_seq_len = x.shape[1] * get_sequence_parallel_world_size() + + # Embed vace_context via patch embedding + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # Chunk VACE context along sequence dim BEFORE processing through blocks + c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + + # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward) + for block in self.vace_blocks: + c = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, x, context, t_mod, freqs + ) + + # Hints are already sharded per-rank + hints = torch.unbind(c)[:-1] + return hints + + +def usp_attn_forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + + attn_type = AttnType.FA + ring_impl_type = "basic" + if IS_NPU_AVAILABLE: + attn_type = AttnType.NPU + ring_impl_type = "basic_npu" + x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)( + None, + query=q, + key=k, + value=v, + ) + x = x.flatten(2) + + del q, k, v + getattr(torch, parse_device_type(x.device)).empty_cache() + return self.o(x) + + +def get_current_chunk(x, dim=1): + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim) + ndims = len(chunks[0].shape) + pad_list = [0] * (2 * ndims) + pad_end_index = 2 * (ndims - 1 - dim) + 1 + max_size = chunks[0].size(dim) + chunks = [ + torch.nn.functional.pad( + chunk, + tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]), + value=0 + ) + for chunk in chunks + ] + x = chunks[get_sequence_parallel_rank()] + return x + + +def gather_all_chunks(x, seq_len=None, dim=1): + x = get_sp_group().all_gather(x, dim=dim) + if seq_len is not None: + slices = [slice(None)] * x.ndim + slices[dim] = slice(0, seq_len) + x = x[tuple(slices)] + return x diff --git a/diffsynth/version.py b/diffsynth/version.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcae7a677a5f5e73fbd7f44840f3cf7533ea495 --- /dev/null +++ b/diffsynth/version.py @@ -0,0 +1,5 @@ +# Make sure to modify __release_datetime__ to release time when making official release. +__version__ = '2.0.0' +# default release datetime for branches under active development is set +# to be a time far-far-away-into-the-future +__release_datetime__ = '2099-10-13 08:56:12' \ No newline at end of file
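
A minimal usage sketch, appended for illustration and not part of the patch above: it shows how one of the new state-dict converters might be applied when loading a raw checkpoint. Only WanVideoVAEStateDictConverter and its module path come from the diff; the checkpoint filename is a placeholder, and wiring the converted weights into a concrete VAE instance is left commented out because constructing that model is outside the scope of this example.

# Hedged sketch, assuming torch is available and "Wan2.1_VAE.pth" is a placeholder path.
import torch
from diffsynth.utils.state_dict_converters.wan_video_vae import WanVideoVAEStateDictConverter

raw = torch.load("Wan2.1_VAE.pth", map_location="cpu")  # placeholder checkpoint path
converted = WanVideoVAEStateDictConverter(raw)          # unwraps 'model_state' if present and prefixes keys with "model."
# vae.load_state_dict(converted)                        # 'vae' would be a diffsynth WanVideoVAE instance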
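
A second illustrative sketch of the zero-initialized channel expansion described in the expand_patch_embedding_channels docstring (wan_video_vace.py above). The tensor sizes here are toy values, not the real 96 pretrained plus 16 glyph channels; the point is that the glyph slice of the new weight starts at zero, so the expanded convolution initially reproduces the pretrained behavior.

# Hedged sketch with illustrative sizes only.
import torch

out_ch, orig_in, glyph = 8, 6, 2                            # toy sizes, not the production 5120/96/16
pretrained_w = torch.randn(out_ch, orig_in, 1, 2, 2)         # stands in for the pretrained Conv3D weight
new_w = torch.zeros(out_ch, orig_in + glyph, 1, 2, 2, dtype=pretrained_w.dtype)
new_w[:, :orig_in] = pretrained_w                            # glyph channels stay zero at initialization
assert torch.equal(new_w[:, :orig_in], pretrained_w)         # pretrained behavior is preserved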