Bundle diffsynth library (no external repo dependency)
This view is limited to 50 files because the commit contains too many changes.
- diffsynth/__init__.py +1 -0
- diffsynth/configs/__init__.py +2 -0
- diffsynth/configs/model_configs.py +888 -0
- diffsynth/configs/vram_management_module_maps.py +284 -0
- diffsynth/core/__init__.py +6 -0
- diffsynth/core/attention/__init__.py +1 -0
- diffsynth/core/attention/attention.py +121 -0
- diffsynth/core/data/__init__.py +1 -0
- diffsynth/core/data/operators.py +280 -0
- diffsynth/core/data/unified_dataset.py +118 -0
- diffsynth/core/device/__init__.py +2 -0
- diffsynth/core/device/npu_compatible_device.py +107 -0
- diffsynth/core/gradient/__init__.py +1 -0
- diffsynth/core/gradient/gradient_checkpoint.py +37 -0
- diffsynth/core/loader/__init__.py +3 -0
- diffsynth/core/loader/config.py +119 -0
- diffsynth/core/loader/file.py +130 -0
- diffsynth/core/loader/model.py +115 -0
- diffsynth/core/npu_patch/npu_fused_operator.py +30 -0
- diffsynth/core/vram/__init__.py +2 -0
- diffsynth/core/vram/disk_map.py +93 -0
- diffsynth/core/vram/initialization.py +21 -0
- diffsynth/core/vram/layers.py +479 -0
- diffsynth/diffusion/__init__.py +6 -0
- diffsynth/diffusion/base_pipeline.py +500 -0
- diffsynth/diffusion/flow_match.py +236 -0
- diffsynth/diffusion/logger.py +43 -0
- diffsynth/diffusion/loss.py +158 -0
- diffsynth/diffusion/parsers.py +71 -0
- diffsynth/diffusion/runner.py +135 -0
- diffsynth/diffusion/training_module.py +302 -0
- diffsynth/models/anima_dit.py +1307 -0
- diffsynth/models/dinov3_image_encoder.py +96 -0
- diffsynth/models/flux2_dit.py +1053 -0
- diffsynth/models/flux2_text_encoder.py +58 -0
- diffsynth/models/flux2_vae.py +0 -0
- diffsynth/models/flux_controlnet.py +384 -0
- diffsynth/models/flux_dit.py +398 -0
- diffsynth/models/flux_infiniteyou.py +129 -0
- diffsynth/models/flux_ipadapter.py +110 -0
- diffsynth/models/flux_lora_encoder.py +521 -0
- diffsynth/models/flux_lora_patcher.py +306 -0
- diffsynth/models/flux_text_encoder_clip.py +112 -0
- diffsynth/models/flux_text_encoder_t5.py +43 -0
- diffsynth/models/flux_vae.py +451 -0
- diffsynth/models/flux_value_control.py +56 -0
- diffsynth/models/general_modules.py +146 -0
- diffsynth/models/longcat_video_dit.py +902 -0
- diffsynth/models/ltx2_audio_vae.py +1872 -0
- diffsynth/models/ltx2_common.py +388 -0
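The bulk of the bundle is the model registry in `diffsynth/configs/model_configs.py`, shown below: each entry maps a checkpoint fingerprint (`model_hash`) to a component name (`model_name`), a dotted import path for the implementing class (`model_class`), optional constructor overrides (`extra_kwargs`), and an optional `state_dict_converter` that rewrites foreign checkpoint layouts into the bundled module's key names. As a minimal sketch of how such an entry can be consumed (the helper names and the converter's `convert()` interface are assumptions for illustration, not the bundled loader's actual API):

```python
import importlib

def resolve(dotted_path: str):
    """Turn 'pkg.module.ClassName' into the class object it names."""
    module_path, attr = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr)

def build_model(model_hash: str, state_dict: dict, configs: list):
    """Illustrative lookup: match the hash, instantiate the class, load weights."""
    for entry in configs:
        if entry["model_hash"] != model_hash:
            continue
        model_cls = resolve(entry["model_class"])
        # extra_kwargs selects the variant (dims, layer counts, feature flags).
        model = model_cls(**entry.get("extra_kwargs", {}))
        if "state_dict_converter" in entry:
            # Assumed interface: converters rename foreign checkpoint keys into
            # the bundled module's state-dict layout before loading.
            state_dict = resolve(entry["state_dict_converter"])().convert(state_dict)
        model.load_state_dict(state_dict)
        return entry["model_name"], model
    raise KeyError(f"no registered model for hash {model_hash!r}")
```

Note that several hashes appear more than once in the registry (for instance, the FLUX `ae.safetensors` hash yields both `flux_vae_encoder` and `flux_vae_decoder`), so the real loader presumably materializes every matching entry; the sketch returns only the first for brevity.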
diffsynth/__init__.py
ADDED
@@ -0,0 +1 @@
```python
from .core import *
```
diffsynth/configs/__init__.py
ADDED
@@ -0,0 +1,2 @@
```python
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
```
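`MODEL_CONFIGS` keys models by a 32-hex `model_hash` rather than by file name, so the same weights are recognized regardless of where they were downloaded from or what the file is called. The exact fingerprinting scheme lives in the bundled loader; the sketch below only illustrates the general idea of hashing a checkpoint's tensor-name layout rather than its weight values, and `fingerprint_state_dict` is not the library's actual function:

```python
import hashlib

def fingerprint_state_dict(state_dict: dict) -> str:
    """Illustrative: derive a 32-hex fingerprint from a checkpoint's key layout.

    Hashing the sorted tensor names (not the tensor values) keeps the
    fingerprint stable across fine-tunes of the same architecture.
    """
    joined = ",".join(sorted(state_dict))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

# Two checkpoints with identical parameter names map to the same fingerprint,
# even though their weight values differ.
a = {"blocks.0.attn.to_q.weight": 0, "blocks.0.attn.to_k.weight": 1}
b = {"blocks.0.attn.to_k.weight": 9, "blocks.0.attn.to_q.weight": 8}
assert fingerprint_state_dict(a) == fingerprint_state_dict(b)
```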
diffsynth/configs/model_configs.py
ADDED
@@ -0,0 +1,888 @@
```python
qwen_image_series = [
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "0319a1cb19835fb510907dd3367c95ff",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
        "model_hash": "8004730443f55db63092006dd9f7110e",
        "model_name": "qwen_image_text_encoder",
        "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "ed4ea5824d55ec3107b09815e318123a",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
        "model_hash": "073bce9cf969e317e5662cd570c3e79c",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
        "model_hash": "a9e54e480a628f0b956a688a81c33bab",
        "model_name": "qwen_image_blockwise_controlnet",
        "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
        "extra_kwargs": {"additional_in_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
        "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
        "model_name": "siglip2_image_encoder",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
        "model_hash": "5722b5c873720009de96422993b15682",
        "model_name": "dinov3_image_encoder",
        "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
    },
    {
        # Example:
        "model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
        "model_name": "qwen_image_image2lora_coarse",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
    },
    {
        # Example:
        "model_hash": "a5476e691767a4da6d3a6634a10f7408",
        "model_name": "qwen_image_image2lora_fine",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
    },
    {
        # Example:
        "model_hash": "0aad514690602ecaff932c701cb4b0bb",
        "model_name": "qwen_image_image2lora_style",
        "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 64, "use_residual": False}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
        "model_name": "qwen_image_dit",
        "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
        "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "44b39ddc499e027cfb24f7878d7416b9",
        "model_name": "qwen_image_vae",
        "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
        "extra_kwargs": {"image_channels": 4}
    },
]

wan_series = [
    {
        # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
        "model_hash": "5ec04e02b42d2580483ad69f4e76346a",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
        "model_hash": "9c8818c2cbea55eca56c7b447df170da",
        "model_name": "wan_video_text_encoder",
        "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
        "model_hash": "ccc42284ea13e1ad04693284c7a09be6",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
        "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
    },
    {
        # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
        "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
        "model_name": "wan_video_vap",
        "model_class": "diffsynth.models.wan_video_mot.MotWanModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
        "model_hash": "5941c53e207d62f20f9025686193c40b",
        "model_name": "wan_video_image_encoder",
        "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
        "model_hash": "dbd5ec76bbf977983f972c151d545389",
        "model_name": "wan_video_motion_controller",
        "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "9269f8db9040a9d860eaca435be61814",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "349723183fc063b2bfc10bb2835cf677",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6d6ccde6845b95ad9114ab993d917893",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "6bfcfb3b342cb286ce886889d519a77e",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "70ddad9d3a133785da5ea371aae09504",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "b61c605c2adbd23124d152ed28e049ae",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "extra_kwargs": {"use_target_text_encoder": True},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
        "model_name": "wan_video_vace",
        "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
        "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'glyph_channels': 16, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
        "model_name": "wan_video_animate_adapter",
        "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "47dbeab5e560db3180adf51dc0232fb1",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "2267d489f0ceb9f21836532952852ee5",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
        "model_hash": "5b013604280dd715f8457c6ed6d6a626",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "966cffdcc52f9c46c391768b27637614",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
        "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
        "model_hash": "1f5ab7703c6fc803fdded85ff040c316",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
        "model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
        "model_name": "wan_video_vae",
        "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
        "model_hash": "06be60f3a4526586d8431cd038a71486",
        "model_name": "wans2v_audio_encoder",
        "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
        "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
        "model_name": "wan_video_dit",
        "model_class": "diffsynth.models.wan_video_dit.WanModel",
        "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
    },
]

flux_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
        "model_hash": "a29710fea6dddb0314663ee823598e50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Supported for historical reasons.
        "model_hash": "605c56eab23e9e2af863ad8f0813a25d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
        "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
        "model_name": "flux_text_encoder_clip",
        "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
        "model_hash": "22540b49eaedbc2f2784b2091a234c7c",
        "model_name": "flux_text_encoder_t5",
        "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
        "model_hash": "d02f41c13549fa5093d3521f62a5570a",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
        "model_hash": "0629116fce1472503a66992f96f3eb1a",
        "model_name": "flux_value_controller",
        "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
    },
    {
        # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "52357cb26250681367488a8954c271e8",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "78d18b9101345ff695f312e7e62538c0",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
    },
    {
        # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
        "model_hash": "b001c89139b5f053c715fe772362dd2a",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_single_blocks": 0},
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
        "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
        "model_name": "infiniteyou_image_projector",
        "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
        "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
        "model_name": "flux_controlnet",
        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
        "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
        "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
        "model_name": "flux_lora_encoder",
        "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
        "model_hash": "30143afb2dea73d1ac580e0787628f8c",
        "model_name": "flux_lora_patcher",
        "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
        "model_hash": "2bd19e845116e4f875a0a048e27fc219",
        "model_name": "nexus_gen_llm",
        "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "nexus_gen_editing_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
        "model_hash": "63c969fd37cce769a90aa781fbff5f81",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "nexus_gen_generation_adapter",
        "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
        "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
        "model_hash": "4daaa66cc656a8fe369908693dad0a35",
        "model_name": "flux_ipadapter",
        "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
        "model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
        "model_name": "siglip_vision_model",
        "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors")
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "step1x_connector",
        "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors")
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
        "extra_kwargs": {"disable_guidance_embedder": True},
    },
    {
        # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
        "model_hash": "3394f306c4cbf04334b712bf5aaed95f",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
    },
]

flux2_series = [
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
        "model_name": "flux2_text_encoder",
        "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
        "model_hash": "c54288e3ee12ca215898840682337b95",
        "model_name": "flux2_vae",
        "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "3bde7b817fec8143028b6825a63180df",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "9195f3ea256fcd0ae6d929c203470754",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "8B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
        "model_name": "flux2_dit",
        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
    },
]

z_image_series = [
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
        "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_encoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
        "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
        "model_name": "flux_vae_decoder",
        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
        "extra_kwargs": {"use_conv_attention": False},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
        "model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
        "extra_kwargs": {"siglip_feat_dim": 1152},
    },
    {
        # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
        "model_hash": "89d48e420f45cff95115a9f3e698d44a",
        "model_name": "siglip_vision_model_428m",
        "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
    },
    {
        # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
        "model_hash": "1677708d40029ab380a95f6c731a57d7",
        "model_name": "z_image_controlnet",
        "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
    },
    {
        # Example: ???
        "model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
        "model_name": "z_image_image2lora_style",
        "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
        "extra_kwargs": {"compress_dim": 128},
    },
    {
        # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
        "model_hash": "1392adecee344136041e70553f875f31",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "0.6B"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
    },
    {
        # Ensures compatibility with the `model.diffusion_model` prefix introduced by other frameworks.
        "model_hash": "8cf241a0d32f93d5de368502a086852f",
        "model_name": "z_image_dit",
        "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter",
    },
]
"""
Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For LTX-2 base models, both the official checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and the repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are supported.
We have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo to support loading each submodule separately,
avoiding redundant memory usage when users only need part of the model.
"""
ltx2_series = [
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
        "model_hash": "c567aaa37d5ed7454c73aa6024458661",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
        "model_hash": "7f7e904a53260ec0351b05f32153754b",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
        "model_hash": "dc6029ca2825147872b45e35a2dc3a97",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
        "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "f471360f6b24bef702ab73133d9f8bb9",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
        "model_hash": "29338f3b95e7e312a3460a482e4f4554",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "aca7b0bbf8415e9c98360750268915fc",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
        "model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
        "model_hash": "33917f31c4a79196171154cca39f165e",
        "model_name": "ltx2_text_encoder",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
        "model_hash": "c79c458c6e99e0e14d47e676761732d2",
        "model_name": "ltx2_latent_upsampler",
        "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_dit",
        "model_class": "diffsynth.models.ltx2_dit.LTXModel",
        "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_video_vae_encoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
        "extra_kwargs": {"encoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_video_vae_decoder",
        "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
        "extra_kwargs": {"decoder_version": "ltx-2.3"},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
        "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
```
|
| 765 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
| 766 |
+
},
|
| 767 |
+
{
|
| 768 |
+
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
| 769 |
+
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
| 770 |
+
"model_name": "ltx2_audio_vocoder",
|
| 771 |
+
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
|
| 772 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
| 773 |
+
},
|
| 774 |
+
{
|
| 775 |
+
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
| 776 |
+
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
| 777 |
+
"model_name": "ltx2_audio_vae_encoder",
|
| 778 |
+
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
| 779 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
| 780 |
+
},
|
| 781 |
+
{
|
| 782 |
+
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
|
| 783 |
+
"model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
|
| 784 |
+
"model_name": "ltx2_text_encoder_post_modules",
|
| 785 |
+
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
| 786 |
+
"extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
|
| 787 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
| 788 |
+
},
|
| 789 |
+
{
|
| 790 |
+
# Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 791 |
+
"model_hash": "aed408774d694a2452f69936c32febb5",
|
| 792 |
+
"model_name": "ltx2_latent_upsampler",
|
| 793 |
+
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
| 794 |
+
"extra_kwargs": {"rational_resampler": False},
|
| 795 |
+
},
|
| 796 |
+
{
|
| 797 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors")
|
| 798 |
+
"model_hash": "1c55afad76ed33c112a2978550b524d1",
|
| 799 |
+
"model_name": "ltx2_dit",
|
| 800 |
+
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
| 801 |
+
"extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
|
| 802 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
|
| 806 |
+
"model_hash": "eecdc07c2ec30863b8a2b8b2134036cf",
|
| 807 |
+
"model_name": "ltx2_video_vae_encoder",
|
| 808 |
+
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
| 809 |
+
"extra_kwargs": {"encoder_version": "ltx-2.3"},
|
| 810 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
| 811 |
+
},
|
| 812 |
+
{
|
| 813 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
|
| 814 |
+
"model_hash": "deda2f542e17ee25bc8c38fd605316ea",
|
| 815 |
+
"model_name": "ltx2_video_vae_decoder",
|
| 816 |
+
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
| 817 |
+
"extra_kwargs": {"decoder_version": "ltx-2.3"},
|
| 818 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
| 819 |
+
},
|
| 820 |
+
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
        "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
        "model_name": "ltx2_audio_vae_decoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
        "model_hash": "29338f3b95e7e312a3460a482e4f4554",
        "model_name": "ltx2_audio_vae_encoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
        "model_hash": "cd436c99e69ec5c80f050f0944f02a15",
        "model_name": "ltx2_audio_vocoder",
        "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
        "model_hash": "05da2aab1c4b061f72c426311c165a43",
        "model_name": "ltx2_text_encoder_post_modules",
        "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
        "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
        "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
    },
]

anima_series = [
    {
        # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
        "model_hash": "a9995952c2d8e63cf82e115005eb61b9",
        "model_name": "z_image_text_encoder",
        "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
        "extra_kwargs": {"model_size": "0.6B"},
    },
    {
        # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
        "model_hash": "417673936471e79e31ed4d186d7a3f4a",
        "model_name": "anima_dit",
        "model_class": "diffsynth.models.anima_dit.AnimaDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
    }
]

mova_series = [
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
        "model_name": "mova_audio_dit",
        "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
        "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
        "model_name": "mova_audio_vae",
        "model_class": "diffsynth.models.mova_audio_vae.DacVAE",
    },
    # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
    {
        "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
        "model_name": "mova_dual_tower_bridge",
        "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
    },
]

MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
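
Each entry above keys a checkpoint by a hash of its tensor layout rather than by file name, so one packed file (e.g. ltx-2-19b-dev.safetensors) can resolve to several components. A minimal sketch of querying the table, assuming only the bundled MODEL_CONFIGS export; the find_model_configs helper is hypothetical, not part of the library:

    from diffsynth.configs import MODEL_CONFIGS

    def find_model_configs(model_hash):
        # Hypothetical helper: one hash may match several entries when a
        # packed checkpoint contains multiple model components.
        return [cfg for cfg in MODEL_CONFIGS if cfg["model_hash"] == model_hash]

    for cfg in find_model_configs("aca7b0bbf8415e9c98360750268915fc"):
        print(cfg["model_name"], "->", cfg["model_class"])
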
diffsynth/configs/vram_management_module_maps.py
ADDED
@@ -0,0 +1,284 @@
flux_general_vram_config = {
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}

VRAM_MANAGEMENT_MODULE_MAPS = {
    "diffsynth.models.qwen_image_dit.QwenImageDiT": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_vae.QwenImageVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
        "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit.WanModel": {
        "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
        "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_mot.MotWanModel": {
        "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vace.VaceWanModel": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE38": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_dit.FluxDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
    "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
    "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
    "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
    "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
    "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
    "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
    "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_dit.Flux2DiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_vae.Flux2VAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_dit.ZImageDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_controlnet.ZImageControlNet": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
        "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.ltx2_dit.LTXModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.anima_dit.AnimaDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_dit.MovaAudioDit": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.mova_audio_vae.DacVAE": {
        "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}

def QwenImageTextEncoder_Module_Map_Updater():
    current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
    from packaging import version
    import transformers
    if version.parse(transformers.__version__) >= version.parse("5.2.0"):
        # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
        current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
        current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
    return current

VERSION_CHECKER_MAPS = {
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}
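
VERSION_CHECKER_MAPS lets a model's module map be patched at lookup time to match the installed library, as the Qwen updater above does for transformers 5.2.0+. A minimal sketch of how a caller might resolve a map, assuming only the two exported names; resolve_module_map is illustrative, not part of the library:

    from diffsynth.configs import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS

    def resolve_module_map(model_class_path):
        # Illustrative helper: prefer the version-aware updater when one is
        # registered, otherwise fall back to the static map.
        checker = VERSION_CHECKER_MAPS.get(model_class_path)
        return checker() if checker is not None else VRAM_MANAGEMENT_MODULE_MAPS[model_class_path]

    module_map = resolve_module_map("diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder")
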
diffsynth/core/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *
diffsynth/core/attention/__init__.py
ADDED
@@ -0,0 +1 @@
from .attention import attention_forward
diffsynth/core/attention/attention.py
ADDED
@@ -0,0 +1,121 @@
import torch, os
from einops import rearrange


try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
    XFORMERS_AVAILABLE = False


def initialize_attention_priority():
    if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
        return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
    elif FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    elif FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    elif SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    elif XFORMERS_AVAILABLE:
        return "xformers"
    else:
        return "torch"


ATTENTION_IMPLEMENTATION = initialize_attention_priority()


def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if q_pattern != required_in_pattern:
        q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
    if k_pattern != required_in_pattern:
        k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
    if v_pattern != required_in_pattern:
        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
    return q, k, v


def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if out_pattern != required_out_pattern:
        out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
    return out


def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    if isinstance(out, tuple):
        out = out[0]
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = sageattn(q, k, v, sm_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = xops.memory_efficient_attention(q, k, v, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
    else:
        if ATTENTION_IMPLEMENTATION == "flash_attention_3":
            return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
            return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "sage_attention":
            return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "xformers":
            return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        else:
            return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
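
attention_forward lets callers keep whatever q/k/v layout they have and describe it with einops patterns; each backend rearranges tensors into its required layout and back. A minimal usage sketch, assuming none of the optional backends is installed so the torch SDPA fallback runs on CPU; shapes are illustrative:

    import torch
    from diffsynth.core.attention import attention_forward

    # Batch 2, 8 heads, 128 tokens, head dim 64, stored head-first ("b n s d").
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    # Request the output token-first with heads merged; dims supplies the head
    # count needed to expand "(n d)" in the output pattern. Passing attn_mask
    # would force the torch SDPA path regardless of the selected backend.
    out = attention_forward(q, k, v, q_pattern="b n s d", out_pattern="b s (n d)", dims={"n": 8})
    print(out.shape)  # torch.Size([2, 128, 512])
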
diffsynth/core/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset
diffsynth/core/data/operators.py
ADDED
@@ -0,0 +1,280 @@
| 1 |
+
import math
|
| 2 |
+
import torch, torchvision, imageio, os
|
| 3 |
+
import imageio.v3 as iio
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchaudio
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DataProcessingPipeline:
|
| 9 |
+
def __init__(self, operators=None):
|
| 10 |
+
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
|
| 11 |
+
|
| 12 |
+
def __call__(self, data):
|
| 13 |
+
for operator in self.operators:
|
| 14 |
+
data = operator(data)
|
| 15 |
+
return data
|
| 16 |
+
|
| 17 |
+
def __rshift__(self, pipe):
|
| 18 |
+
if isinstance(pipe, DataProcessingOperator):
|
| 19 |
+
pipe = DataProcessingPipeline([pipe])
|
| 20 |
+
return DataProcessingPipeline(self.operators + pipe.operators)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DataProcessingOperator:
|
| 24 |
+
def __call__(self, data):
|
| 25 |
+
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
|
| 26 |
+
|
| 27 |
+
def __rshift__(self, pipe):
|
| 28 |
+
if isinstance(pipe, DataProcessingOperator):
|
| 29 |
+
pipe = DataProcessingPipeline([pipe])
|
| 30 |
+
return DataProcessingPipeline([self]).__rshift__(pipe)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DataProcessingOperatorRaw(DataProcessingOperator):
|
| 34 |
+
def __call__(self, data):
|
| 35 |
+
return data
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ToInt(DataProcessingOperator):
|
| 39 |
+
def __call__(self, data):
|
| 40 |
+
return int(data)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ToFloat(DataProcessingOperator):
|
| 44 |
+
def __call__(self, data):
|
| 45 |
+
return float(data)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ToStr(DataProcessingOperator):
|
| 49 |
+
def __init__(self, none_value=""):
|
| 50 |
+
self.none_value = none_value
|
| 51 |
+
|
| 52 |
+
def __call__(self, data):
|
| 53 |
+
if data is None: data = self.none_value
|
| 54 |
+
return str(data)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LoadImage(DataProcessingOperator):
|
| 58 |
+
def __init__(self, convert_RGB=True, convert_RGBA=False):
|
| 59 |
+
self.convert_RGB = convert_RGB
|
| 60 |
+
self.convert_RGBA = convert_RGBA
|
| 61 |
+
|
| 62 |
+
def __call__(self, data: str):
|
| 63 |
+
image = Image.open(data)
|
| 64 |
+
if self.convert_RGB: image = image.convert("RGB")
|
| 65 |
+
if self.convert_RGBA: image = image.convert("RGBA")
|
| 66 |
+
return image
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ImageCropAndResize(DataProcessingOperator):
|
| 70 |
+
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
|
| 71 |
+
self.height = height
|
| 72 |
+
self.width = width
|
| 73 |
+
self.max_pixels = max_pixels
|
| 74 |
+
self.height_division_factor = height_division_factor
|
| 75 |
+
self.width_division_factor = width_division_factor
|
| 76 |
+
|
| 77 |
+
def crop_and_resize(self, image, target_height, target_width):
|
| 78 |
+
width, height = image.size
|
| 79 |
+
scale = max(target_width / width, target_height / height)
|
| 80 |
+
image = torchvision.transforms.functional.resize(
|
| 81 |
+
image,
|
| 82 |
+
(round(height*scale), round(width*scale)),
|
| 83 |
+
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
| 84 |
+
)
|
| 85 |
+
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
| 86 |
+
return image
|
| 87 |
+
|
| 88 |
+
def get_height_width(self, image):
|
| 89 |
+
if self.height is None or self.width is None:
|
| 90 |
+
width, height = image.size
|
| 91 |
+
if width * height > self.max_pixels:
|
| 92 |
+
scale = (width * height / self.max_pixels) ** 0.5
|
| 93 |
+
height, width = int(height / scale), int(width / scale)
|
| 94 |
+
height = height // self.height_division_factor * self.height_division_factor
|
| 95 |
+
width = width // self.width_division_factor * self.width_division_factor
|
| 96 |
+
else:
|
| 97 |
+
height, width = self.height, self.width
|
| 98 |
+
return height, width
|
| 99 |
+
|
| 100 |
+
def __call__(self, data: Image.Image):
|
| 101 |
+
image = self.crop_and_resize(data, *self.get_height_width(data))
|
| 102 |
+
return image
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class ToList(DataProcessingOperator):
|
| 106 |
+
def __call__(self, data):
|
| 107 |
+
return [data]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class FrameSamplerByRateMixin:
|
| 111 |
+
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):
|
| 112 |
+
self.num_frames = num_frames
|
| 113 |
+
self.time_division_factor = time_division_factor
|
| 114 |
+
self.time_division_remainder = time_division_remainder
|
| 115 |
+
self.frame_rate = frame_rate
|
| 116 |
+
self.fix_frame_rate = fix_frame_rate
|
| 117 |
+
|
| 118 |
+
def get_reader(self, data: str):
|
| 119 |
+
return imageio.get_reader(data)
|
| 120 |
+
|
| 121 |
+
def get_available_num_frames(self, reader):
|
| 122 |
+
if not self.fix_frame_rate:
|
| 123 |
+
return reader.count_frames()
|
| 124 |
+
meta_data = reader.get_meta_data()
|
| 125 |
+
total_original_frames = int(reader.count_frames())
|
| 126 |
+
duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps']
|
| 127 |
+
total_available_frames = math.floor(duration * self.frame_rate)
|
| 128 |
+
return int(total_available_frames)
|
| 129 |
+
|
| 130 |
+
def get_num_frames(self, reader):
|
| 131 |
+
num_frames = self.num_frames
|
| 132 |
+
total_frames = self.get_available_num_frames(reader)
|
| 133 |
+
if int(total_frames) < num_frames:
|
| 134 |
+
num_frames = total_frames
|
| 135 |
+
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
| 136 |
+
num_frames -= 1
|
| 137 |
+
return num_frames
|
| 138 |
+
|
| 139 |
+
def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:
|
| 140 |
+
if not self.fix_frame_rate:
|
| 141 |
+
return new_sequence_id
|
| 142 |
+
target_time_in_seconds = new_sequence_id / self.frame_rate
|
| 143 |
+
raw_frame_index_float = target_time_in_seconds * raw_frame_rate
|
| 144 |
+
frame_id = int(round(raw_frame_index_float))
|
| 145 |
+
frame_id = min(frame_id, total_raw_frames - 1)
|
| 146 |
+
return frame_id
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):
|
| 150 |
+
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):
|
| 151 |
+
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
| 152 |
+
# frame_processor is build in the video loader for high efficiency.
|
| 153 |
+
self.frame_processor = frame_processor
|
| 154 |
+
|
| 155 |
+
def __call__(self, data: str):
|
| 156 |
+
reader = self.get_reader(data)
|
| 157 |
+
raw_frame_rate = reader.get_meta_data()['fps']
|
| 158 |
+
total_raw_frames = reader.count_frames()
|
| 159 |
+
total_available = self.get_available_num_frames(reader)
|
| 160 |
+
# Pad short videos with the last frame instead of reducing num_frames
|
| 161 |
+
num_frames = self.num_frames
|
| 162 |
+
frames = []
|
| 163 |
+
for frame_id in range(num_frames):
|
| 164 |
+
if frame_id < total_available:
|
| 165 |
+
raw_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)
|
| 166 |
+
frame = reader.get_data(raw_id)
|
| 167 |
+
frame = Image.fromarray(frame)
|
| 168 |
+
frame = self.frame_processor(frame)
|
| 169 |
+
frames.append(frame)
|
| 170 |
+
else:
|
| 171 |
+
# Pad with the last frame
|
| 172 |
+
frames.append(frames[-1])
|
| 173 |
+
reader.close()
|
| 174 |
+
return frames
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class SequencialProcess(DataProcessingOperator):
|
| 178 |
+
def __init__(self, operator=lambda x: x):
|
| 179 |
+
self.operator = operator
|
| 180 |
+
|
| 181 |
+
def __call__(self, data):
|
| 182 |
+
return [self.operator(i) for i in data]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class LoadGIF(DataProcessingOperator):
|
| 186 |
+
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
| 187 |
+
self.num_frames = num_frames
|
| 188 |
+
self.time_division_factor = time_division_factor
|
| 189 |
+
self.time_division_remainder = time_division_remainder
|
| 190 |
+
# frame_processor is build in the video loader for high efficiency.
|
        self.frame_processor = frame_processor

    def get_num_frames(self, path):
        num_frames = self.num_frames
        images = iio.imread(path, mode="RGB")
        if len(images) < num_frames:
            num_frames = len(images)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        num_frames = self.get_num_frames(data)
        frames = []
        images = iio.imread(data, mode="RGB")
        for img in images:
            frame = Image.fromarray(img)
            frame = self.frame_processor(frame)
            frames.append(frame)
            if len(frames) >= num_frames:
                break
        return frames


class RouteByExtensionName(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        file_ext_name = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or file_ext_name in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")


class RouteByType(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")


class LoadTorchPickle(DataProcessingOperator):
    def __init__(self, map_location="cpu"):
        self.map_location = map_location

    def __call__(self, data):
        return torch.load(data, map_location=self.map_location, weights_only=False)


class ToAbsolutePath(DataProcessingOperator):
    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        return os.path.join(self.base_path, data)


class LoadAudio(DataProcessingOperator):
    def __init__(self, sr=16000):
        self.sr = sr

    def __call__(self, data: str):
        import librosa
        input_audio, sample_rate = librosa.load(data, sr=self.sr)
        return input_audio


class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):

    def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
        FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)

    def __call__(self, data: str):
        reader = self.get_reader(data)
        num_frames = self.get_num_frames(reader)
        duration = num_frames / self.frame_rate
        waveform, sample_rate = torchaudio.load(data)
        target_samples = int(duration * sample_rate)
        current_samples = waveform.shape[-1]
        if current_samples > target_samples:
            waveform = waveform[..., :target_samples]
        elif current_samples < target_samples:
            padding = target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        return waveform, sample_rate
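
The routing operators above are the glue of the data pipeline: RouteByExtensionName dispatches on the file suffix, RouteByType on the Python type, and a `None` key acts as a catch-all. A minimal sketch (not part of the diff), assuming the `>>` chaining defined on DataProcessingOperator earlier in this file and a hypothetical input path:

router = RouteByExtensionName(operator_map=[
    (("pth", "pt"), LoadTorchPickle()),              # torch pickles, matched by extension
    (None, ToAbsolutePath("/data") >> LoadAudio()),  # catch-all: treat anything else as audio
])
features = router("clip.pth")  # matched by ("pth", "pt") and loaded with torch.load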
diffsynth/core/data/unified_dataset.py
ADDED
@@ -0,0 +1,118 @@
from .operators import *
import torch, json, pandas


class UnifiedDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
        max_data_items=None,
    ):
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.max_data_items = max_data_items
        self.data = []
        self.cached_data = []
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
        frame_rate=24, fix_frame_rate=False,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                    frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, "r") as f:
                for line in f:
                    metadata.append(json.loads(line.strip()))
            self.data = metadata
        else:
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        if self.load_from_cache:
            data = self.cached_data[data_id % len(self.cached_data)]
            data = self.cached_data_operator(data)
        else:
            data = self.data[data_id % len(self.data)].copy()
            for key in self.data_file_keys:
                if key in data:
                    if key in self.special_operator_map:
                        data[key] = self.special_operator_map[key](data[key])
                    else:
                        data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        if self.max_data_items is not None:
            return self.max_data_items
        elif self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        # Debug only
        if len(data1) != len(data2):
            return False
        for k in data1:
            if data1[k] != data2[k]:
                return False
        return True
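
A hypothetical usage sketch (the directory layout and the "image" column name are assumptions, not part of the diff): with a metadata file, samples are run through the operator pipeline; without one, the dataset falls back to scanning base_path for cached .pth files.

dataset = UnifiedDataset(
    base_path="data/my_dataset",                   # hypothetical layout
    metadata_path="data/my_dataset/metadata.csv",  # csv/json/jsonl are supported
    data_file_keys=("image",),
    main_data_operator=UnifiedDataset.default_image_operator(base_path="data/my_dataset"),
)
sample = dataset[0]  # metadata row with "image" replaced by the processed frames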
diffsynth/core/device/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
diffsynth/core/device/npu_compatible_device.py
ADDED
@@ -0,0 +1,107 @@
import importlib
import torch
from typing import Any


def is_torch_npu_available():
    return importlib.util.find_spec("torch_npu") is not None


IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()

if IS_NPU_AVAILABLE:
    import torch_npu

    torch.npu.config.allow_internal_format = False


def get_device_type() -> str:
    """Get the device type of the current machine; currently only CPU, CUDA, and NPU are supported."""
    if IS_CUDA_AVAILABLE:
        device = "cuda"
    elif IS_NPU_AVAILABLE:
        device = "npu"
    else:
        device = "cpu"
    return device


def get_torch_device() -> Any:
    """Get the torch namespace for the detected device type, e.g. torch.cuda or torch.npu."""
    device_name = get_device_type()
    try:
        return getattr(torch, device_name)
    except AttributeError:
        print(f"Device namespace '{device_name}' not found in torch, falling back to 'torch.cuda'.")
        return torch.cuda


def get_device_id() -> int:
    """Get the current device id for the detected device type."""
    return get_torch_device().current_device()


def get_device_name() -> str:
    """Get the current device name, e.g. "cuda:0"."""
    return f"{get_device_type()}:{get_device_id()}"


def synchronize() -> None:
    """Synchronize the current device."""
    get_torch_device().synchronize()


def empty_cache() -> None:
    """Empty the cache of the current device."""
    get_torch_device().empty_cache()


def get_nccl_backend() -> str:
    """Return the distributed communication backend for the detected device type."""
    if IS_CUDA_AVAILABLE:
        return "nccl"
    elif IS_NPU_AVAILABLE:
        return "hccl"
    else:
        raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")


def enable_high_precision_for_bf16():
    """Use a high-precision accumulation dtype for matmul and reduction."""
    if IS_CUDA_AVAILABLE:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

    if IS_NPU_AVAILABLE:
        torch.npu.matmul.allow_tf32 = False
        torch.npu.matmul.allow_bf16_reduced_precision_reduction = False


def parse_device_type(device):
    if isinstance(device, str):
        if device.startswith("cuda"):
            return "cuda"
        elif device.startswith("npu"):
            return "npu"
        else:
            return "cpu"
    elif isinstance(device, torch.device):
        return device.type


def parse_nccl_backend(device_type):
    if device_type == "cuda":
        return "nccl"
    elif device_type == "npu":
        return "hccl"
    else:
        raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")


def get_available_device_type():
    return get_device_type()
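
A quick sketch of how these helpers compose (outputs depend on the machine running the code):

import torch
from diffsynth.core.device import parse_device_type, parse_nccl_backend, get_device_name

print(get_device_name())                         # e.g. "cuda:0" or "npu:0"
print(parse_device_type(torch.device("npu:1")))  # "npu"
print(parse_nccl_backend("cuda"))                # "nccl"; "npu" maps to "hccl"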
diffsynth/core/gradient/__init__.py
ADDED
@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward
diffsynth/core/gradient/gradient_checkpoint.py
ADDED
@@ -0,0 +1,37 @@
import torch
import warnings

# Suppress the checkpoint requires_grad warning: gradients flow through the model parameters, not the inputs.
warnings.filterwarnings("ignore", message=".*None of the inputs have requires_grad.*")


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        # Keep the checkpointed activations in CPU memory to further reduce VRAM usage.
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=True,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=True,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
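
A minimal sketch of the three modes (the block and tensor are placeholders, not part of the diff):

import torch

block = torch.nn.Linear(64, 64)
x = torch.randn(2, 64, requires_grad=True)
y = gradient_checkpoint_forward(block, True, False, x)  # checkpointing on, CPU offload off
y.sum().backward()  # activations inside `block` are recomputed during this backward pass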
diffsynth/core/loader/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig
diffsynth/core/loader/config.py
ADDED
@@ -0,0 +1,119 @@
import torch, glob, os
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download


@dataclass
class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_source: str = None
    local_model_path: str = None
    skip_download: bool = None
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    onload_device: Optional[Union[str, torch.device]] = None
    onload_dtype: Optional[torch.dtype] = None
    preparing_device: Optional[Union[str, torch.device]] = None
    preparing_dtype: Optional[torch.dtype] = None
    computation_device: Optional[Union[str, torch.device]] = None
    computation_dtype: Optional[torch.dtype] = None
    clear_parameters: bool = False
    state_dict: Dict[str, torch.Tensor] = None

    def check_input(self):
        if self.path is None and self.model_id is None:
            raise ValueError("""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")

    def parse_original_file_pattern(self):
        if self.origin_file_pattern in [None, "", "./"]:
            return "*"
        elif self.origin_file_pattern.endswith("/"):
            return self.origin_file_pattern + "*"
        else:
            return self.origin_file_pattern

    def parse_download_source(self):
        if self.download_source is None:
            if os.environ.get("DIFFSYNTH_DOWNLOAD_SOURCE") is not None:
                return os.environ.get("DIFFSYNTH_DOWNLOAD_SOURCE")
            else:
                return "modelscope"
        else:
            return self.download_source

    def parse_skip_download(self):
        if self.skip_download is None:
            if os.environ.get("DIFFSYNTH_SKIP_DOWNLOAD") is not None:
                if os.environ.get("DIFFSYNTH_SKIP_DOWNLOAD").lower() == "true":
                    return True
                elif os.environ.get("DIFFSYNTH_SKIP_DOWNLOAD").lower() == "false":
                    return False
            else:
                return False
        else:
            return self.skip_download

    def download(self):
        origin_file_pattern = self.parse_original_file_pattern()
        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
        download_source = self.parse_download_source()
        if download_source.lower() == "modelscope":
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_file_pattern=origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        elif download_source.lower() == "huggingface":
            hf_snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_patterns=origin_file_pattern,
                ignore_patterns=downloaded_files,
                local_files_only=False
            )
        else:
            raise ValueError("`download_source` should be `modelscope` or `huggingface`.")

    def require_downloading(self):
        if self.path is not None:
            return False
        skip_download = self.parse_skip_download()
        return not skip_download

    def reset_local_model_path(self):
        if os.environ.get("DIFFSYNTH_MODEL_BASE_PATH") is not None:
            self.local_model_path = os.environ.get("DIFFSYNTH_MODEL_BASE_PATH")
        elif self.local_model_path is None:
            self.local_model_path = "./models"

    def download_if_necessary(self):
        self.check_input()
        self.reset_local_model_path()
        if self.require_downloading():
            self.download()
        if self.path is None:
            if self.origin_file_pattern in [None, "", "./"]:
                self.path = os.path.join(self.local_model_path, self.model_id)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]

    def vram_config(self):
        return {
            "offload_device": self.offload_device,
            "offload_dtype": self.offload_dtype,
            "onload_device": self.onload_device,
            "onload_dtype": self.onload_dtype,
            "preparing_device": self.preparing_device,
            "preparing_dtype": self.preparing_dtype,
            "computation_device": self.computation_device,
            "computation_dtype": self.computation_dtype,
        }
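
The two supported ways to point a ModelConfig at weights, sketched with a hypothetical model id and file pattern:

local = ModelConfig(path="models/dit.safetensors")  # already on disk, nothing to download
remote = ModelConfig(model_id="org/model", origin_file_pattern="dit/*.safetensors")
remote.download_if_necessary()  # fetches from ModelScope by default, then fills in remote.path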
diffsynth/core/loader/file.py
ADDED
@@ -0,0 +1,130 @@
from safetensors import safe_open
import torch, hashlib


def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
    else:
        if verbose >= 1:
            print(f"Loading file [started]: {file_path}")
        if file_path.endswith(".safetensors"):
            state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
        else:
            state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
        # When the state dict is kept in CPU memory, `pin_memory=True` makes a later `model.to("cuda")` faster.
        if pin_memory:
            for i in state_dict:
                state_dict[i] = state_dict[i].pin_memory()
        if verbose >= 1:
            print(f"Loading file [done]: {file_path}")
    return state_dict


def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
    state_dict = {}
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if len(state_dict) == 1:
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        elif "module" in state_dict:
            state_dict = state_dict["module"]
        elif "model_state" in state_dict:
            state_dict = state_dict["model_state"]
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, torch.Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()


def load_keys_dict(file_path):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_keys_dict(file_path_))
        return state_dict
    if file_path.endswith(".safetensors"):
        return load_keys_dict_from_safetensors(file_path)
    else:
        return load_keys_dict_from_bin(file_path)


def load_keys_dict_from_safetensors(file_path):
    keys_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            keys_dict[k] = f.get_slice(k).get_shape()
    return keys_dict


def convert_state_dict_to_keys_dict(state_dict):
    keys_dict = {}
    for k, v in state_dict.items():
        if isinstance(v, torch.Tensor):
            keys_dict[k] = list(v.shape)
        else:
            keys_dict[k] = convert_state_dict_to_keys_dict(v)
    return keys_dict


def load_keys_dict_from_bin(file_path):
    state_dict = load_state_dict_from_bin(file_path)
    keys_dict = convert_state_dict_to_keys_dict(state_dict)
    return keys_dict


def convert_keys_dict_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, dict):
                keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
            else:
                if with_shape:
                    shape = "_".join(map(str, list(value)))
                    keys.append(key + ":" + shape)
                keys.append(key)
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_model_file(path, with_shape=True):
    keys_dict = load_keys_dict(path)
    keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()
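
A sketch of the hashing helpers (the file name is hypothetical). Both hashes are an MD5 over the sorted "name" and "name:shape" strings, so hash_model_file can identify a checkpoint's architecture from safetensors headers without reading tensor data:

state_dict = load_state_dict("dit.safetensors")
print(hash_state_dict_keys(state_dict))  # hashed from the loaded tensors
print(hash_model_file("dit.safetensors"))  # expected to match, computed from headers only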
diffsynth/core/loader/model.py
ADDED
@@ -0,0 +1,115 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers


def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
    config = {} if config is None else config
    # Skip ZeRO-3 initialization for VAEs to avoid compatibility issues.
    skip_zero3 = "vae" in model_class.__name__.lower() if hasattr(model_class, "__name__") else False
    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device, skip_zero3=skip_zero3)):
        model = model_class(**config)
    # What is `module_map`?
    # It is a module mapping table for VRAM management.
    if module_map is not None:
        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
        device = [d for d in devices if d != "disk"][0]
        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
        dtype = [d for d in dtypes if d != "disk"][0]
        if vram_config["offload_device"] != "disk":
            if state_dict is None:
                state_dict = DiskMap(path, device, torch_dtype=dtype)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            else:
                state_dict = {i: state_dict[i] for i in state_dict}
            if is_deepspeed_zero3_enabled():
                from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
                _load_state_dict_into_zero3_model(model, state_dict)
            else:
                model.load_state_dict(state_dict, assign=True)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
        else:
            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    else:
        # Why do we use `DiskMap`?
        # Sometimes a model file contains multiple models,
        # and DiskMap can load only the parameters of a single model,
        # avoiding the need to load all parameters in the file.
        if state_dict is not None:
            pass
        elif use_disk_map:
            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
        else:
            state_dict = load_state_dict(path, torch_dtype, device)
        # Why do we use `state_dict_converter`?
        # Some models are saved in complex formats,
        # and we need to convert the state dict into the appropriate format.
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        else:
            state_dict = {i: state_dict[i] for i in state_dict}
        # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
        # Because at this stage, model parameters are partitioned across multiple GPUs.
        # Loading them directly could lead to excessive GPU memory consumption.
        if is_deepspeed_zero3_enabled():
            from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
            _load_state_dict_into_zero3_model(model, state_dict)
        else:
            model.load_state_dict(state_dict, assign=True)
        # Why do we call `to()`?
        # Because some models override the behavior of `to()`,
        # especially those from libraries like Transformers.
        model = model.to(dtype=torch_dtype, device=device)
    if hasattr(model, "eval"):
        model = model.eval()
    return model


def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
    if isinstance(path, str):
        path = [path]
    config = {} if config is None else config
    with skip_model_initialization():
        model = model_class(**config)
    if hasattr(model, "eval"):
        model = model.eval()
    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
    vram_config = {
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.float8_e4m3fn,
        "preparing_device": device,
        "computation_dtype": torch_dtype,
        "computation_device": device,
    }
    enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
    return model


def get_init_context(torch_dtype, device, skip_zero3=False):
    if is_deepspeed_zero3_enabled() and not skip_zero3:
        from transformers.modeling_utils import set_zero3_state
        import deepspeed
        # Why do we use `deepspeed.zero.Init`?
        # It partitions the model weights on the CPU side
        # and loads the partitioned weights onto the compute devices.
        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
    elif skip_zero3:
        # For models excluded from ZeRO-3 (e.g. VAE), use normal initialization
        # instead of skip_model_initialization to avoid meta tensor issues.
        init_contexts = []
    else:
        # Why do we use `skip_model_initialization`?
        # It skips the random initialization of model parameters,
        # thereby speeding up model loading and avoiding excessive memory usage.
        init_contexts = [skip_model_initialization()]

    return init_contexts
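
A hypothetical call (the class name and file path are placeholders for illustration; any of the bundled classes under diffsynth/models could stand in):

import torch
from diffsynth.models.flux_dit import FluxDiT  # class name assumed, not confirmed by this diff

dit = load_model(FluxDiT, "models/flux/dit.safetensors", torch_dtype=torch.bfloat16, device="cuda")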
diffsynth/core/npu_patch/npu_fused_operator.py
ADDED
@@ -0,0 +1,30 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
    import torch_npu
except ImportError:
    pass


def rms_norm_forward_npu(self, hidden_states):
    """NPU fused RMS norm operator for RMSNorm.forward from diffsynth/models/general_modules.py."""
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]


def rms_norm_forward_transformers_npu(self, hidden_states):
    """NPU fused RMS norm operator for transformers."""
    if hidden_states.dtype != self.weight.dtype:
        hidden_states = hidden_states.to(self.weight.dtype)
    return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]


def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
    """NPU fused RoPE operator for Z-Image."""
    with torch.amp.autocast(get_device_type(), enabled=False):
        freqs_cis = freqs_cis.unsqueeze(2)
        cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
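
These functions are drop-in replacements for the corresponding forward methods, so on NPU they would be monkey-patched onto the target classes; a sketch, with the patching site assumed rather than shown in this diff:

from diffsynth.core.device import IS_NPU_AVAILABLE
from diffsynth.models.general_modules import RMSNorm  # target named in the docstring above

if IS_NPU_AVAILABLE:
    RMSNorm.forward = rms_norm_forward_npu  # route RMSNorm through the fused NPU kernel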
diffsynth/core/vram/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *
diffsynth/core/vram/disk_map.py
ADDED
@@ -0,0 +1,93 @@
from safetensors import safe_open
import torch, os


class SafetensorsCompatibleTensor:
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return list(self.tensor.shape)


class SafetensorsCompatibleBinaryLoader:
    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert them to safetensors files.")
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        return self.state_dict.keys()

    def get_tensor(self, name):
        return self.state_dict[name]

    def get_slice(self, name):
        return SafetensorsCompatibleTensor(self.state_dict[name])


class DiskMap:

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        self.torch_dtype = torch_dtype
        if os.environ.get("DIFFSYNTH_DISK_MAP_BUFFER_SIZE") is not None:
            self.buffer_size = int(os.environ.get("DIFFSYNTH_DISK_MAP_BUFFER_SIZE"))
        else:
            self.buffer_size = buffer_size
        self.files = []
        self.flush_files()
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        if len(self.files) == 0:
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        if self.rename_dict is not None:
            name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if self.torch_dtype is not None and isinstance(param, torch.Tensor):
            param = param.to(self.torch_dtype)
        if isinstance(param, torch.Tensor) and param.device.type == "cpu":
            param = param.clone()
        if isinstance(param, torch.Tensor):
            self.num_params += param.numel()
            if self.num_params > self.buffer_size:
                self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        if state_dict_converter is None:
            return None
        state_dict = {}
        for file in self.files:
            for name in file.keys():
                state_dict[name] = name
        state_dict = state_dict_converter(state_dict)
        return state_dict

    def __iter__(self):
        if self.rename_dict is not None:
            return self.rename_dict.__iter__()
        else:
            return self.name_map.__iter__()

    def __contains__(self, x):
        if self.rename_dict is not None:
            return x in self.rename_dict
        else:
            return x in self.name_map
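
In short, DiskMap quacks like a state dict but reads lazily; a minimal sketch with a hypothetical file:

import torch

disk_map = DiskMap("dit.safetensors", device="cpu", torch_dtype=torch.bfloat16)
for name in disk_map:
    tensor = disk_map[name]  # fetched from disk on access; file handles are reopened
                             # once the running element count exceeds buffer_size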
diffsynth/core/vram/initialization.py
ADDED
@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager


@contextmanager
def skip_model_initialization(device=torch.device("meta")):

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    old_register_parameter = torch.nn.Module.register_parameter
    torch.nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
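
A minimal sketch: parameters registered inside the context are created on the meta device, so construction allocates no real memory until weights are assigned (e.g. via load_state_dict(..., assign=True)):

import torch

with skip_model_initialization():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta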
diffsynth/core/vram/layers.py
ADDED
@@ -0,0 +1,479 @@
| 1 |
+
import torch, copy
|
| 2 |
+
from typing import Union
|
| 3 |
+
from .initialization import skip_model_initialization
|
| 4 |
+
from .disk_map import DiskMap
|
| 5 |
+
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AutoTorchModule(torch.nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
offload_dtype: torch.dtype = None,
|
| 13 |
+
offload_device: Union[str, torch.device] = None,
|
| 14 |
+
onload_dtype: torch.dtype = None,
|
| 15 |
+
onload_device: Union[str, torch.device] = None,
|
| 16 |
+
preparing_dtype: torch.dtype = None,
|
| 17 |
+
preparing_device: Union[str, torch.device] = None,
|
| 18 |
+
computation_dtype: torch.dtype = None,
|
| 19 |
+
computation_device: Union[str, torch.device] = None,
|
| 20 |
+
vram_limit: float = None,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.set_dtype_and_device(
|
| 24 |
+
offload_dtype,
|
| 25 |
+
offload_device,
|
| 26 |
+
onload_dtype,
|
| 27 |
+
onload_device,
|
| 28 |
+
preparing_dtype,
|
| 29 |
+
preparing_device,
|
| 30 |
+
computation_dtype,
|
| 31 |
+
computation_device,
|
| 32 |
+
vram_limit,
|
| 33 |
+
)
|
| 34 |
+
self.state = 0
|
| 35 |
+
self.name = ""
|
| 36 |
+
self.computation_device_type = parse_device_type(self.computation_device)
|
| 37 |
+
|
| 38 |
+
def set_dtype_and_device(
|
| 39 |
+
self,
|
| 40 |
+
offload_dtype: torch.dtype = None,
|
| 41 |
+
offload_device: Union[str, torch.device] = None,
|
| 42 |
+
onload_dtype: torch.dtype = None,
|
| 43 |
+
onload_device: Union[str, torch.device] = None,
|
| 44 |
+
preparing_dtype: torch.dtype = None,
|
| 45 |
+
preparing_device: Union[str, torch.device] = None,
|
| 46 |
+
computation_dtype: torch.dtype = None,
|
| 47 |
+
computation_device: Union[str, torch.device] = None,
|
| 48 |
+
vram_limit: float = None,
|
| 49 |
+
):
|
| 50 |
+
self.offload_dtype = offload_dtype or computation_dtype
|
| 51 |
+
self.offload_device = offload_device or computation_dtype
|
| 52 |
+
self.onload_dtype = onload_dtype or computation_dtype
|
| 53 |
+
self.onload_device = onload_device or computation_dtype
|
| 54 |
+
self.preparing_dtype = preparing_dtype or computation_dtype
|
| 55 |
+
self.preparing_device = preparing_device or computation_dtype
|
| 56 |
+
self.computation_dtype = computation_dtype
|
| 57 |
+
self.computation_device = computation_device
|
| 58 |
+
self.vram_limit = vram_limit
|
| 59 |
+
|
| 60 |
+
def cast_to(self, weight, dtype, device):
|
| 61 |
+
r = torch.empty_like(weight, dtype=dtype, device=device)
|
| 62 |
+
r.copy_(weight)
|
| 63 |
+
return r
|
| 64 |
+
|
| 65 |
+
def check_free_vram(self):
|
| 66 |
+
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
| 67 |
+
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
| 68 |
+
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
| 69 |
+
return used_memory < self.vram_limit
|
| 70 |
+
|
| 71 |
+
def offload(self):
|
| 72 |
+
if self.state != 0:
|
| 73 |
+
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 74 |
+
self.state = 0
|
| 75 |
+
|
| 76 |
+
def onload(self):
|
| 77 |
+
if self.state != 1:
|
| 78 |
+
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 79 |
+
self.state = 1
|
| 80 |
+
|
| 81 |
+
def param_name(self, name):
|
| 82 |
+
if self.name == "":
|
| 83 |
+
return name
|
| 84 |
+
else:
|
| 85 |
+
return self.name + "." + name
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class AutoWrappedModule(AutoTorchModule):
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
module: torch.nn.Module,
|
| 93 |
+
offload_dtype: torch.dtype = None,
|
| 94 |
+
offload_device: Union[str, torch.device] = None,
|
| 95 |
+
onload_dtype: torch.dtype = None,
|
| 96 |
+
onload_device: Union[str, torch.device] = None,
|
| 97 |
+
preparing_dtype: torch.dtype = None,
|
| 98 |
+
preparing_device: Union[str, torch.device] = None,
|
| 99 |
+
computation_dtype: torch.dtype = None,
|
| 100 |
+
computation_device: Union[str, torch.device] = None,
|
| 101 |
+
vram_limit: float = None,
|
| 102 |
+
name: str = "",
|
| 103 |
+
disk_map: DiskMap = None,
|
| 104 |
+
**kwargs
|
| 105 |
+
):
|
| 106 |
+
super().__init__(
|
| 107 |
+
offload_dtype,
|
| 108 |
+
offload_device,
|
| 109 |
+
onload_dtype,
|
| 110 |
+
onload_device,
|
| 111 |
+
preparing_dtype,
|
| 112 |
+
preparing_device,
|
| 113 |
+
computation_dtype,
|
| 114 |
+
computation_device,
|
| 115 |
+
vram_limit,
|
| 116 |
+
)
|
| 117 |
+
self.module = module
|
| 118 |
+
if offload_dtype == "disk":
|
| 119 |
+
self.name = name
|
| 120 |
+
self.disk_map = disk_map
|
| 121 |
+
self.required_params = [name for name, _ in self.module.named_parameters()]
|
| 122 |
+
self.disk_offload = True
|
| 123 |
+
else:
|
| 124 |
+
self.disk_offload = False
|
| 125 |
+
|
| 126 |
+
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
| 127 |
+
if copy_module:
|
| 128 |
+
module = copy.deepcopy(self.module)
|
| 129 |
+
else:
|
| 130 |
+
module = self.module
|
| 131 |
+
state_dict = {}
|
| 132 |
+
for name in self.required_params:
|
| 133 |
+
param = self.disk_map[self.param_name(name)]
|
| 134 |
+
param = param.to(dtype=torch_dtype, device=device)
|
| 135 |
+
state_dict[name] = param
|
| 136 |
+
module.load_state_dict(state_dict, assign=True)
|
| 137 |
+
module.to(dtype=torch_dtype, device=device)
|
| 138 |
+
return module
|
| 139 |
+
|
| 140 |
+
def offload_to_disk(self, model: torch.nn.Module):
|
| 141 |
+
for buf in model.buffers():
|
| 142 |
+
# If there are some parameters are registed in buffers (not in state dict),
|
| 143 |
+
# We cannot offload the model.
|
| 144 |
+
for children in model.children():
|
| 145 |
+
self.offload_to_disk(children)
|
| 146 |
+
break
|
| 147 |
+
else:
|
| 148 |
+
model.to("meta")
|
| 149 |
+
|
| 150 |
+
def offload(self):
|
| 151 |
+
# offload / onload / preparing -> offload
|
| 152 |
+
if self.state != 0:
|
| 153 |
+
if self.disk_offload:
|
| 154 |
+
self.offload_to_disk(self.module)
|
| 155 |
+
else:
|
| 156 |
+
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 157 |
+
self.state = 0
|
| 158 |
+
|
| 159 |
+
def onload(self):
|
| 160 |
+
# offload / onload / preparing -> onload
|
| 161 |
+
if self.state < 1:
|
| 162 |
+
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
| 163 |
+
self.load_from_disk(self.onload_dtype, self.onload_device)
|
| 164 |
+
elif self.onload_device != "disk":
|
| 165 |
+
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 166 |
+
self.state = 1
|
| 167 |
+
|
| 168 |
+
def preparing(self):
|
| 169 |
+
# onload / preparing -> preparing
|
| 170 |
+
if self.state != 2:
|
| 171 |
+
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
| 172 |
+
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
| 173 |
+
elif self.preparing_device != "disk":
|
| 174 |
+
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
| 175 |
+
self.state = 2
|
| 176 |
+
|
| 177 |
+
def cast_to(self, module, dtype, device):
|
| 178 |
+
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
| 179 |
+
|
| 180 |
+
def computation(self):
|
| 181 |
+
# onload / preparing -> computation (temporary)
|
| 182 |
+
if self.state == 2:
|
| 183 |
+
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
| 184 |
+
else:
|
| 185 |
+
torch_dtype, device = self.onload_dtype, self.onload_device
|
| 186 |
+
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
| 187 |
+
module = self.module
|
| 188 |
+
elif self.disk_offload and device == "disk":
|
| 189 |
+
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
| 190 |
+
else:
|
| 191 |
+
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
| 192 |
+
return module
|
| 193 |
+
|
| 194 |
+
def forward(self, *args, **kwargs):
|
| 195 |
+
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
| 196 |
+
self.preparing()
|
| 197 |
+
module = self.computation()
|
| 198 |
+
return module(*args, **kwargs)
|
| 199 |
+
|
| 200 |
+
def __getattr__(self, name):
|
| 201 |
+
if name in self.__dict__ or name == "module":
|
| 202 |
+
return super().__getattr__(name)
|
| 203 |
+
else:
|
| 204 |
+
return getattr(self.module, name)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
module: torch.nn.Module,
|
| 212 |
+
offload_dtype: torch.dtype = None,
|
| 213 |
+
offload_device: Union[str, torch.device] = None,
|
| 214 |
+
onload_dtype: torch.dtype = None,
|
| 215 |
+
onload_device: Union[str, torch.device] = None,
|
| 216 |
+
preparing_dtype: torch.dtype = None,
|
| 217 |
+
preparing_device: Union[str, torch.device] = None,
|
| 218 |
+
computation_dtype: torch.dtype = None,
|
| 219 |
+
computation_device: Union[str, torch.device] = None,
|
| 220 |
+
vram_limit: float = None,
|
| 221 |
+
name: str = "",
|
| 222 |
+
disk_map: DiskMap = None,
|
| 223 |
+
**kwargs
|
| 224 |
+
):
|
| 225 |
+
super().__init__(
|
| 226 |
+
module,
|
| 227 |
+
offload_dtype,
|
| 228 |
+
offload_device,
|
| 229 |
+
onload_dtype,
|
| 230 |
+
onload_device,
|
| 231 |
+
preparing_dtype,
|
| 232 |
+
preparing_device,
|
| 233 |
+
computation_dtype,
|
| 234 |
+
computation_device,
|
| 235 |
+
vram_limit,
|
| 236 |
+
name,
|
| 237 |
+
disk_map,
|
| 238 |
+
**kwargs
|
| 239 |
+
)
|
| 240 |
+
if self.disk_offload:
|
| 241 |
+
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
|
| 242 |
+
|
| 243 |
+
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
| 244 |
+
if copy_module:
|
| 245 |
+
module = copy.deepcopy(self.module)
|
| 246 |
+
else:
|
| 247 |
+
module = self.module
|
| 248 |
+
state_dict = {}
|
| 249 |
+
for name in self.required_params:
|
| 250 |
+
param = self.disk_map[self.param_name(name)]
|
| 251 |
+
param = param.to(dtype=torch_dtype, device=device)
|
| 252 |
+
state_dict[name] = param
|
| 253 |
+
module.load_state_dict(state_dict, assign=True, strict=False)
|
| 254 |
+
return module
|
| 255 |
+
|
| 256 |
+
def offload_to_disk(self, model: torch.nn.Module):
|
| 257 |
+
for name in self.required_params:
|
| 258 |
+
getattr(self, name).to("meta")
|
| 259 |
+
|
| 260 |
+
def cast_to(self, module, dtype, device):
|
| 261 |
+
# Parameter casting is implemented in the model architecture.
|
| 262 |
+
return module
|
| 263 |
+
|
| 264 |
+
    def __getattr__(self, name):
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)


class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        with skip_model_initialization():
            super().__init__(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.weight = module.weight
        self.bias = module.bias
        self.state = 0
        self.name = name
        self.lora_A_weights = []
        self.lora_B_weights = []
        self.lora_merger = None
        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
        self.computation_device_type = parse_device_type(self.computation_device)

        if offload_dtype == "disk":
            self.disk_map = disk_map
            self.disk_offload = True
        else:
            self.disk_offload = False

    def fp8_linear(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
    ) -> torch.Tensor:
        device = input.device
        origin_dtype = input.dtype
        origin_shape = input.shape
        input = input.reshape(-1, origin_shape[-1])

        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
        fp8_max = 448.0
        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
        # To avoid overflow and ensure numerical compatibility during FP8 computation,
        # we scale down the input by 2.0 in advance.
        # This scaling will be compensated later during the final result scaling.
        if self.computation_dtype == torch.float8_e4m3fnuz:
            fp8_max = fp8_max / 2.0
        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
        input = input / (scale_a + 1e-8)
        input = input.to(self.computation_dtype)
        weight = weight.to(self.computation_dtype)
        bias = None if bias is None else bias.to(torch.bfloat16)

        result = torch._scaled_mm(
            input,
            weight.T,
            scale_a=scale_a,
            scale_b=scale_b.T,
            bias=bias,
            out_dtype=origin_dtype,
        )
        new_shape = origin_shape[:-1] + result.shape[-1:]
        result = result.reshape(new_shape)
        return result
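    # A rough numeric walkthrough of the scaling above (an illustrative note,
    # not part of the computation): with e4m3fn the largest representable
    # magnitude is 448. If a row of the input has x_max = 896, then
    # scale_a = clamp(896 / 448, min=1) = 2, the row is divided by 2 before the
    # FP8 cast, and `torch._scaled_mm` multiplies the product by scale_a again,
    # so the result matches the unscaled matmul up to quantization error.
    # Rows with x_max <= 448 keep scale_a = 1 and are cast unchanged.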
    def load_from_disk(self, torch_dtype, device, assign=True):
        weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
        bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
        if assign:
            state_dict = {"weight": weight}
            if bias is not None: state_dict["bias"] = bias
            self.load_state_dict(state_dict, assign=True)
        return weight, bias

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.to("meta")
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            weight, bias = self.weight, self.bias
        elif self.disk_offload and device == "disk":
            weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
        else:
            weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
        return weight, bias

    def linear_forward(self, x, weight, bias):
        if self.enable_fp8:
            out = self.fp8_linear(x, weight, bias)
        else:
            out = torch.nn.functional.linear(x, weight, bias)
        return out

    def lora_forward(self, x, out):
        if self.lora_merger is None:
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
        else:
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                lora_output.append(x @ lora_A.T @ lora_B.T)
            lora_output = torch.stack(lora_output)
            out = self.lora_merger(out, lora_output)
        return out

    def forward(self, x, *args, **kwargs):
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        weight, bias = self.computation()
        out = self.linear_forward(x, weight, bias)
        if len(self.lora_A_weights) > 0:
            out = self.lora_forward(x, out)
        return out


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
    if isinstance(model, AutoWrappedNonRecurseModule):
        model = model.module
    for name, module in model.named_children():
        layer_name = name if name_prefix == "" else name_prefix + "." + name
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
                if isinstance(module_, AutoWrappedNonRecurseModule):
                    enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
                setattr(model, name, module_)
                break
        else:
            enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)


def fill_vram_config(model, vram_config):
    vram_config_ = vram_config.copy()
    vram_config_["onload_dtype"] = vram_config["computation_dtype"]
    vram_config_["onload_device"] = vram_config["computation_device"]
    vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
    vram_config_["preparing_device"] = vram_config["computation_device"]
    for k in vram_config:
        if vram_config[k] != vram_config_[k]:
            print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
            break
    return vram_config_


def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
    for source_module, target_module in module_map.items():
        # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
        if isinstance(model, source_module):
            vram_config = fill_vram_config(model, vram_config)
            model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
            break
    else:
        enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
    # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
    model.vram_management_enabled = True
    return model
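A minimal usage sketch for `enable_vram_management` above. The module map and the dtype/device values here are illustrative only; the shipped per-architecture maps live in diffsynth/configs/vram_management_module_maps.py.

    import torch
    from diffsynth.core.vram.layers import enable_vram_management, AutoWrappedLinear

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
    model = enable_vram_management(
        model,
        module_map={torch.nn.Linear: AutoWrappedLinear},
        vram_config=dict(
            offload_dtype=torch.float16, offload_device="cpu",
            onload_dtype=torch.float16, onload_device="cuda",
            preparing_dtype=torch.float16, preparing_device="cuda",
            computation_dtype=torch.float16, computation_device="cuda",
        ),
    )
    # Each wrapped Linear now exposes offload()/onload(), and forward()
    # casts weights to the computation dtype/device on the fly.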
diffsynth/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *
diffsynth/diffusion/base_pipeline.py
ADDED
@@ -0,0 +1,500 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE


class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        output_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.output_params = output_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def fetch_input_params(self):
        params = []
        if self.input_params is not None:
            for param in self.input_params:
                params.append(param)
        if self.input_params_posi is not None:
            for _, param in self.input_params_posi.items():
                params.append(param)
        if self.input_params_nega is not None:
            for _, param in self.input_params_nega.items():
                params.append(param)
        params = sorted(list(set(params)))
        return params

    def fetch_output_params(self):
        params = []
        if self.output_params is not None:
            for param in self.output_params:
                params.append(param)
        return params

    def process(self, pipe, **kwargs) -> dict:
        return {}

    def post_process(self, pipe, **kwargs) -> dict:
        return {}


class BasePipeline(torch.nn.Module):

    def __init__(
        self,
        device=get_device_type(), torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        self.device_type = parse_device_type(device)
        # The following parameters are used for shape checks.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # VRAM management
        self.vram_management_enabled = False
        # Pipeline Unit Runner
        self.unit_runner = PipelineUnitRunner()
        # LoRA Loader
        self.lora_loader = GeneralLoRALoader

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            if verbose > 0:
                print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            if verbose > 0:
                print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                if verbose > 0:
                    print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a PIL.Image to torch.Tensor
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a list of PIL.Image to torch.Tensor
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to PIL.Image
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to a list of PIL.Image
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def output_audio_format_check(self, audio_output):
        # Output standard format: [C, T]; output dtype: float32.
        # Remove the batch dim.
        if audio_output.ndim == 3:
            audio_output = audio_output.squeeze(0)
        return audio_output.float()

    def load_models_to_device(self, model_names):
        if self.vram_management_enabled:
            # offload models
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "offload"):
                            model.offload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "offload"):
                                    module.offload()
            getattr(torch, self.device_type).empty_cache()
            # onload models
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "onload"):
                            model.onload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "onload"):
                                    module.onload()

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Initialize Gaussian noise
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def get_vram(self):
        device = self.device if not IS_NPU_AVAILABLE else get_device_name()
        return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)

    def get_module(self, model, name):
        if "." in name:
            name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
            if name.isdigit():
                return self.get_module(model[int(name)], suffix)
            else:
                return self.get_module(getattr(model, name), suffix)
        else:
            return getattr(model, name)

    def freeze_except(self, model_names):
        self.eval()
        self.requires_grad_(False)
        for name in model_names:
            module = self.get_module(self, name)
            if module is None:
                print(f"No {name} model in the pipeline, so training cannot be enabled on it. If this occurs during the data processing stage, it is normal.")
                continue
            module.train()
            module.requires_grad_(True)

    def blend_with_mask(self, base, addition, mask):
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next

    def split_pipeline_units(self, model_names: list[str]):
        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)

    def flush_vram_management_device(self, device):
        for module in self.modules():
            if isinstance(module, AutoTorchModule):
                module.offload_device = device
                module.onload_device = device
                module.preparing_device = device
                module.computation_device = device

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=None,
        state_dict=None,
        verbose=1,
    ):
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        lora = lora_loader.convert_state_dict(lora)
        if hotload is None:
            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
        if hotload:
            if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
                raise ValueError("VRAM management is not enabled. LoRA hotloading is not supported.")
            updated_num = 0
            for _, module in module.named_modules():
                if isinstance(module, AutoWrappedLinear):
                    name = module.name
                    lora_a_name = f'{name}.lora_A.weight'
                    lora_b_name = f'{name}.lora_B.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        updated_num += 1
                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
                        module.lora_B_weights.append(lora[lora_b_name])
            if verbose >= 1:
                print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
        else:
            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)

    def clear_lora(self, verbose=1):
        cleared_num = 0
        for name, module in self.named_modules():
            if isinstance(module, AutoWrappedLinear):
                if hasattr(module, "lora_A_weights"):
                    if len(module.lora_A_weights) > 0:
                        cleared_num += 1
                    module.lora_A_weights.clear()
                if hasattr(module, "lora_B_weights"):
                    module.lora_B_weights.clear()
        if verbose >= 1:
            print(f"{cleared_num} LoRA layers are cleared.")

    def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
        model_pool = ModelPool()
        for model_config in model_configs:
            model_config.download_if_necessary()
            vram_config = model_config.vram_config()
            vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
            vram_config["computation_device"] = vram_config["computation_device"] or self.device
            model_pool.auto_load_model(
                model_config.path,
                vram_config=vram_config,
                vram_limit=vram_limit,
                clear_parameters=model_config.clear_parameters,
                state_dict=model_config.state_dict,
            )
        return model_pool

    def check_vram_management_state(self):
        vram_management_enabled = False
        for module in self.children():
            if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
                vram_management_enabled = True
        return vram_management_enabled

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        if inputs_shared.get("positive_only_lora", None) is not None:
            self.clear_lora(verbose=0)
            self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale != 1.0:
            if inputs_shared.get("positive_only_lora", None) is not None:
                self.clear_lora(verbose=0)
            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
            if isinstance(noise_pred_posi, tuple):
                # Separately handle different output types of latents, e.g. video and audio latents.
                noise_pred = tuple(
                    n_nega + cfg_scale * (n_posi - n_nega)
                    for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
                )
            else:
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi
        return noise_pred

    def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
        """
        Compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
        If a model has a `_repeated_blocks` attribute, we compile these blocks with regional compilation. Otherwise, we compile the whole model.
        See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
        Args:
            mode: The compilation mode, passed to `torch.compile`. Options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs". Defaults to "default".
            dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, passed to `torch.compile`. Defaults to True (recommended).
            fullgraph: Whether to use full graph compilation, passed to `torch.compile`. Defaults to False (recommended).
            compile_models: The list of model names to be compiled. If None, we compile the models in `pipeline.compilable_models`. Defaults to None.
            **kwargs: Other arguments for `torch.compile`.
        """
        compile_models = compile_models or getattr(self, "compilable_models", [])
        if len(compile_models) == 0:
            print("No compilable models in the pipeline. Skip compilation.")
            return
        for name in compile_models:
            model = getattr(self, name, None)
            if model is None:
                print(f"Model '{name}' not found in the pipeline.")
                continue
            repeated_blocks = getattr(model, "_repeated_blocks", None)
            # Regional compilation for repeated blocks.
            if repeated_blocks is not None:
                for submod in model.modules():
                    if submod.__class__.__name__ in repeated_blocks:
                        submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
            # Compile the whole model.
            else:
                model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
            print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
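# A minimal round-trip sketch of the tensor conventions above (illustrative
# only; real pipelines subclass BasePipeline):
#
#     pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
#     image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
#     x = pipe.preprocess_image(image)        # (1, 3, 64, 64), values in [-1, 1]
#     image_ = pipe.vae_output_to_image(x)    # back to a PIL.Image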
class PipelineUnitGraph:
    def __init__(self):
        pass

    def build_edges(self, units: list[PipelineUnit]):
        # Establish dependencies between units
        # to search for subsequent related computation units.
        last_compute_unit_id = {}
        edges = []
        for unit_id, unit in enumerate(units):
            for input_param in unit.fetch_input_params():
                if input_param in last_compute_unit_id:
                    edges.append((last_compute_unit_id[input_param], unit_id))
            for output_param in unit.fetch_output_params():
                last_compute_unit_id[output_param] = unit_id
        return edges

    def build_chains(self, units: list[PipelineUnit]):
        # Establish updating chains for each variable
        # to track their computation process.
        params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
        params = sorted(list(set(params)))
        chains = {param: [] for param in params}
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_output_params():
                chains[param].append(unit_id)
        return chains

    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
        # Search for units that directly participate in the model's computation.
        related_unit_ids = []
        for unit_id, unit in enumerate(units):
            for model_name in model_names:
                if unit.onload_model_names is not None and model_name in unit.onload_model_names:
                    related_unit_ids.append(unit_id)
                    break
        return related_unit_ids

    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
        # Search for subsequent related computation units.
        related_unit_ids = [unit_id for unit_id in start_unit_ids]
        while True:
            neighbors = []
            for source, target in edges:
                if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
                    neighbors.append(target)
                elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
                    neighbors.append(source)
            neighbors = sorted(list(set(neighbors)))
            if len(neighbors) == 0:
                break
            else:
                related_unit_ids.extend(neighbors)
        related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
        # If the input parameters of this subgraph are updated outside the subgraph,
        # search for the units where these updates occur.
        first_compute_unit_id = {}
        for unit_id in related_unit_ids:
            for param in units[unit_id].fetch_input_params():
                if param not in first_compute_unit_id:
                    first_compute_unit_id[param] = unit_id
        updating_unit_ids = []
        for param in first_compute_unit_id:
            unit_id = first_compute_unit_id[param]
            chain = chains[param]
            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
                for unit_id_ in chain[chain.index(unit_id) + 1:]:
                    if unit_id_ not in related_unit_ids:
                        updating_unit_ids.append(unit_id_)
        related_unit_ids.extend(updating_unit_ids)
        related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
        # Split the computation graph,
        # separating all model-related computations.
        related_unit_ids = self.search_direct_unit_ids(units, model_names)
        edges = self.build_edges(units)
        chains = self.build_chains(units)
        while True:
            num_related_unit_ids = len(related_unit_ids)
            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
            if len(related_unit_ids) == num_related_unit_ids:
                break
        related_units = [units[i] for i in related_unit_ids]
        unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
        return related_units, unrelated_units


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                # When CFG is disabled, reuse the positive-side outputs.
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
diffsynth/diffusion/flow_match.py
ADDED
@@ -0,0 +1,236 @@
import torch, math
from typing_extensions import Literal


class FlowMatchScheduler():

    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
        self.set_timesteps_fn = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
            "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
            "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
        }.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

    @staticmethod
    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 5 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu

    @staticmethod
    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        shift_terminal = 0.02
        # Sigmas
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Shift terminal
        one_minus_z = 1 - sigmas
        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
        sigmas = 1 - (one_minus_z / scale_factor)
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps
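    # A worked example of the linear interpolation in `_calculate_shift_qwen_image`
    # with its default constants (illustrative only):
    #   m  = (0.9 - 0.5) / (8192 - 256) ~= 5.04e-5
    #   b  = 0.5 - m * 256             ~= 0.4871
    #   mu = 4224 * m + b              ~= 0.700  for a 4224-token image
    # Longer sequences therefore get a larger mu, i.e. a stronger exponential
    # shift of the sigma schedule towards high-noise timesteps.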
    @staticmethod
    def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        base_shift = math.log(3)
        max_shift = math.log(3)
        # Sigmas
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def compute_empirical_mu(image_seq_len, num_steps):
        a1, b1 = 8.73809524e-05, 1.89833333
        a2, b2 = 0.00016927, 0.45666666

        if image_seq_len > 4300:
            mu = a2 * image_seq_len + b2
            return float(mu)

        m_200 = a2 * image_seq_len + b2
        m_10 = a1 * image_seq_len + b1

        a = (m_200 - m_10) / 190.0
        b = m_200 - 200.0 * a
        mu = a * num_steps + b

        return float(mu)

    @staticmethod
    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
        sigma_min = 1 / num_inference_steps
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        if dynamic_shift_len is None:
            # If you ask me why I set mu=0.8,
            # I can only say that it yields better training results.
            mu = 0.8
        else:
            mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        if target_timesteps is not None:
            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
            for timestep in target_timesteps:
                timestep_id = torch.argmin((timesteps - timestep).abs())
                timesteps[timestep_id] = timestep
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
        num_train_timesteps = 1000
        if special_case == "stage2":
            sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
        elif special_case == "ditilled_stage1":
            sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
        else:
            dynamic_shift_len = dynamic_shift_len or 4096
            sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
                image_seq_len=dynamic_shift_len,
                base_seq_len=1024,
                max_seq_len=4096,
                base_shift=0.95,
                max_shift=2.05,
            )
            sigma_min = 0.0
            sigma_max = 1.0
            sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
            sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
            sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
            # Shift terminal
            one_minus_z = 1.0 - sigmas
            scale_factor = one_minus_z[-1] / (1 - terminal)
            sigmas = 1.0 - (one_minus_z / scale_factor)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    def set_training_weight(self):
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # This is an empirical formula.
            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
        self.linear_timesteps_weights = bsmntw_weighing

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
        self.sigmas, self.timesteps = self.set_timesteps_fn(
            num_inference_steps=num_inference_steps,
            denoising_strength=denoising_strength,
            **kwargs,
        )
        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
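A minimal sampling-loop sketch with the scheduler above, assuming the "FLUX.1" template; the latent shape and the lambda denoiser are stand-ins for a real DiT.

    import torch
    from diffsynth.diffusion.flow_match import FlowMatchScheduler

    scheduler = FlowMatchScheduler(template="FLUX.1")
    scheduler.set_timesteps(num_inference_steps=20)
    latents = torch.randn(1, 16, 64, 64)          # hypothetical latent shape
    model = lambda x, t: torch.zeros_like(x)      # stand-in denoiser
    for timestep in scheduler.timesteps:
        noise_pred = model(latents, timestep)
        latents = scheduler.step(noise_pred, timestep, latents)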
diffsynth/diffusion/logger.py
ADDED
@@ -0,0 +1,43 @@
import os, torch
from accelerate import Accelerator


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x, resume_step=0):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter
        self.num_steps = resume_step

    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
        accelerator.wait_for_everyone()
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)
diffsynth/diffusion/loss.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_pipeline import BasePipeline
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
| 6 |
+
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
| 7 |
+
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
| 8 |
+
|
| 9 |
+
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
| 10 |
+
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 11 |
+
|
| 12 |
+
noise = torch.randn_like(inputs["input_latents"])
|
| 13 |
+
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
| 14 |
+
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
| 15 |
+
|
| 16 |
+
if "first_frame_latents" in inputs:
|
| 17 |
+
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
| 18 |
+
|
| 19 |
+
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
| 20 |
+
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
| 21 |
+
|
| 22 |
+
if "first_frame_latents" in inputs:
|
| 23 |
+
noise_pred = noise_pred[:, :, 1:]
|
| 24 |
+
training_target = training_target[:, :, 1:]
|
| 25 |
+
|
| 26 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
| 27 |
+
loss = loss * pipe.scheduler.training_weight(timestep)
|
| 28 |
+
return loss
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
| 32 |
+
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
| 33 |
+
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
| 34 |
+
|
| 35 |
+
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
| 36 |
+
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
| 37 |
+
|
| 38 |
+
# video
|
| 39 |
+
noise = torch.randn_like(inputs["input_latents"])
|
| 40 |
+
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
| 41 |
+
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
| 42 |
+
|
| 43 |
+
# audio
|
| 44 |
+
if inputs.get("audio_input_latents") is not None:
|
| 45 |
+
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
| 46 |
+
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
| 47 |
+
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
| 48 |
+
|
| 49 |
+
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
| 50 |
+
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
| 51 |
+
|
| 52 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
| 53 |
+
loss = loss * pipe.scheduler.training_weight(timestep)
|
| 54 |
+
if inputs.get("audio_input_latents") is not None:
|
| 55 |
+
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
| 56 |
+
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
| 57 |
+
loss = loss + loss_audio
|
| 58 |
+
return loss
|
| 59 |
+
|
| 60 |
+
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss


class TrajectoryImitationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, device):
        import lpips  # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        trajectory = [inputs_shared["latents"].clone()]

        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            denom = sigma_ - sigma
            denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
            target = (latents_ - inputs_shared["latents"]) / denom
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
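`TrajectoryImitationLoss` expects the teacher pipeline to ride along in `inputs_shared["teacher"]`: it runs the teacher for 50 CFG-guided steps, then trains the student against an 8-step schedule. A hedged sketch of the calling contract only; `student_pipe`, `teacher_pipe`, `initial_noise`, and the input dictionaries are hypothetical placeholders, not names defined in this repository.

criterion = TrajectoryImitationLoss()
inputs_shared = {"latents": initial_noise, "teacher": teacher_pipe}  # hypothetical wiring
loss = criterion(student_pipe, inputs_shared, inputs_posi, inputs_nega)
loss.backward()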
diffsynth/diffusion/parsers.py
ADDED
@@ -0,0 +1,71 @@
import argparse


def add_dataset_base_config(parser: argparse.ArgumentParser):
    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
    return parser

def add_image_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    return parser

def add_video_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    return parser

def add_model_config(parser: argparse.ArgumentParser):
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
    parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in split training.")
    return parser

def add_training_config(parser: argparse.ArgumentParser):
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    return parser

def add_output_config(parser: argparse.ArgumentParser):
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    parser.add_argument("--save_steps", type=int, default=None, help="Interval (in steps) between checkpoint saves. If None, checkpoints are saved every epoch.")
    parser.add_argument("--resume_step", type=int, default=0, help="Starting step count when resuming. ModelLogger.num_steps initializes here; training stops when num_steps reaches num_epochs * steps_per_epoch.")
    return parser

def add_lora_config(parser: argparse.ArgumentParser):
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
    parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
    return parser

def add_gradient_config(parser: argparse.ArgumentParser):
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser

def add_general_config(parser: argparse.ArgumentParser):
    parser = add_dataset_base_config(parser)
    parser = add_model_config(parser)
    parser = add_training_config(parser)
    parser = add_output_config(parser)
    parser = add_lora_config(parser)
    parser = add_gradient_config(parser)
    return parser
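Each builder returns the parser it receives, so the helpers compose by simple chaining. A minimal usage sketch (the flags are the real ones defined above; the values are illustrative):

import argparse
from diffsynth.diffusion.parsers import add_general_config, add_video_size_config

parser = argparse.ArgumentParser(description="DiffSynth training")
parser = add_general_config(parser)
parser = add_video_size_config(parser)
args = parser.parse_args(["--dataset_base_path", "data/", "--lora_base_model", "dit"])
print(args.lora_rank)  # 32, the default from add_lora_config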
diffsynth/diffusion/runner.py
ADDED
@@ -0,0 +1,135 @@
import os, torch
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    # Exclude VAE from DeepSpeed ZeRO-3 wrapping to avoid compatibility issues:
    # store the VAE outside the module tree so DeepSpeed doesn't touch it.
    vae_module = getattr(model.pipe, 'vae', None)
    if vae_module is not None:
        del model.pipe._modules['vae']
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    if vae_module is not None:
        vae_module.to(accelerator.device)
        # Store VAE as a non-module attribute so pipeline code can still use pipe.vae
        pipe = model.module.pipe if hasattr(model, 'module') else model.pipe
        # Use object.__setattr__ to bypass nn.Module's __setattr__, which would register it as a submodule
        object.__setattr__(pipe, 'vae', vae_module)
    initialize_deepspeed_gradient_checkpointing(accelerator)
    # Training log file
    log_path = os.path.join(model_logger.output_path, "training_log.txt")
    if accelerator.is_main_process:
        os.makedirs(model_logger.output_path, exist_ok=True)
        log_file = open(log_path, "a")
        log_file.write(f"Training started. Epochs: {num_epochs}, LR: {learning_rate}, Steps/epoch: {len(dataloader)}\n")
        log_file.flush()
    else:
        log_file = None

    total_target = num_epochs * len(dataloader)
    reached_target = False
    for epoch_id in range(num_epochs):
        if reached_target:
            break
        progress = tqdm(
            total=total_target,
            initial=model_logger.num_steps,
            desc=f"Epoch {epoch_id+1}/{num_epochs}",
        )
        for step_id, data in enumerate(dataloader):
            if model_logger.num_steps >= total_target:
                reached_target = True
                break
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()

            # Log loss
            loss_val = loss.item()
            progress.update(1)
            progress.set_postfix(loss=f"{loss_val:.4f}")
            if accelerator.is_main_process and log_file is not None and (model_logger.num_steps % 10 == 0 or model_logger.num_steps <= 5):
                log_file.write(f"epoch={epoch_id+1} step={model_logger.num_steps} loss={loss_val:.6f}\n")
                log_file.flush()
        progress.close()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
            if accelerator.is_main_process and log_file is not None:
                log_file.write(f"Epoch {epoch_id+1} completed. Checkpoint saved.\n")
                log_file.flush()
    model_logger.on_training_end(accelerator, model, save_steps)
    if log_file is not None:
        log_file.close()
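A sketch of how `launch_training_task` might be driven. Here `my_dataset` and `my_training_module` are hypothetical placeholders, and the `ModelLogger` keyword argument is an assumption; check `logger.py` for its actual constructor signature.

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=1)
model_logger = ModelLogger(output_path="./models")  # assumed constructor, see logger.py
launch_training_task(
    accelerator=accelerator,
    dataset=my_dataset,           # must expose a `load_from_cache` attribute
    model=my_training_module,     # a DiffusionTrainingModule subclass instance
    model_logger=model_logger,
    learning_rate=1e-4,
    num_epochs=1,
)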
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, dataloader = accelerator.prepare(model, dataloader)

    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)


def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
        if "activation_checkpointing" in ds_config:
            import deepspeed
            act_config = ds_config["activation_checkpointing"]
            deepspeed.checkpointing.configure(
                mpu_=None,
                partition_activations=act_config.get("partition_activations", False),
                checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
                contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
            )
        else:
            print("No activation_checkpointing section found in the DeepSpeed config; skipping DeepSpeed gradient checkpointing initialization.")
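`initialize_deepspeed_gradient_checkpointing` only takes effect when the DeepSpeed config carries an `activation_checkpointing` section. The keys it reads are standard DeepSpeed options; an illustrative fragment of such a config, expressed as a Python dict (values are examples, not recommendations):

ds_config_fragment = {
    "activation_checkpointing": {
        "partition_activations": True,           # shard checkpointed activations across GPUs
        "cpu_checkpointing": True,               # offload checkpointed activations to CPU memory
        "contiguous_memory_optimization": False, # copy checkpoints into one contiguous buffer
    }
}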
diffsynth/diffusion/training_module.py
ADDED
@@ -0,0 +1,302 @@
import torch, json, os, inspect
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from .base_pipeline import PipelineUnit
from peft import LoraConfig, inject_adapter_in_model


class GeneralUnit_RemoveCache(PipelineUnit):
    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        super().__init__(take_over=True)
        self.required_params = required_params
        self.force_remove_params_shared = force_remove_params_shared
        self.force_remove_params_posi = force_remove_params_posi
        self.force_remove_params_nega = force_remove_params_nega

    def process_params(self, inputs, required_params, force_remove_params):
        inputs_ = {}
        for name, param in inputs.items():
            if name in required_params and name not in force_remove_params:
                inputs_[name] = param
        return inputs_

    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
        return inputs_shared, inputs_posi, inputs_nega


class DiffusionTrainingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        for name, model in self.named_children():
            model.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
        return trainable_modules

    def trainable_param_names(self):
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        return trainable_param_names

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
        if lora_alpha is None:
            lora_alpha = lora_rank
        if isinstance(target_modules, list) and len(target_modules) == 1:
            target_modules = target_modules[0]
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        if upcast_dtype is not None:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(upcast_dtype)
        return model

    def mapping_lora_state_dict(self, state_dict):
        new_state_dict = {}
        for key, value in state_dict.items():
            if "lora_A.weight" in key or "lora_B.weight" in key:
                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                new_state_dict[new_key] = value
            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
                new_state_dict[key] = value
        return new_state_dict
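`mapping_lora_state_dict` reconciles the two key layouts PEFT produces for LoRA weights: checkpoints saved without an adapter name versus modules injected under the `default` adapter. A small illustration with a made-up key:

import torch

module = DiffusionTrainingModule()
old = {"blocks.0.attn.q.lora_A.weight": torch.zeros(32, 1024)}
new = module.mapping_lora_state_dict(old)
print(list(new))  # ['blocks.0.attn.q.lora_A.default.weight']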
    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        if remove_prefix is not None:
            state_dict_ = {}
            for name, param in state_dict.items():
                if name.startswith(remove_prefix):
                    name = name[len(remove_prefix):]
                state_dict_[name] = param
            state_dict = state_dict_
        return state_dict

    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
        if data is None:
            return data
        elif isinstance(data, torch.Tensor):
            data = data.to(device)
            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
                data = data.to(torch_float_dtype)
            return data
        elif isinstance(data, tuple):
            data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, list):
            data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, dict):
            data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
            return data
        else:
            return data

    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
        if fp8:
            return {
                "offload_dtype": torch.float8_e4m3fn,
                "offload_device": device,
                "onload_dtype": torch.float8_e4m3fn,
                "onload_device": device,
                "preparing_dtype": torch.float8_e4m3fn,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
            }
        elif offload:
            return {
                "offload_dtype": "disk",
                "offload_device": "disk",
                "onload_dtype": "disk",
                "onload_device": "disk",
                "preparing_dtype": torch.bfloat16,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
                "clear_parameters": True,
            }
        else:
            return {}

    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
        fp8_models = [] if fp8_models is None else fp8_models.split(",")
        offload_models = [] if offload_models is None else offload_models.split(",")
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            for path in model_paths:
                vram_config = self.parse_vram_config(
                    fp8=path in fp8_models,
                    offload=path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(path=path, **vram_config))
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            for model_id_with_origin_path in model_id_with_origin_paths:
                vram_config = self.parse_vram_config(
                    fp8=model_id_with_origin_path in fp8_models,
                    offload=model_id_with_origin_path in offload_models,
                    device=device
                )
                config = self.parse_path_or_model_id(model_id_with_origin_path)
                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
        return model_configs

    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
        if model_id_with_origin_path is None:
            return default_value
        elif os.path.exists(model_id_with_origin_path):
            return ModelConfig(path=model_id_with_origin_path)
        else:
            if ":" not in model_id_with_origin_path:
                raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
            split_id = model_id_with_origin_path.rfind(":")
            model_id = model_id_with_origin_path[:split_id]
            origin_file_pattern = model_id_with_origin_path[split_id + 1:]
            return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
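`parse_path_or_model_id` splits on the last colon, so model IDs containing slashes (or even colons) survive intact. For example:

module = DiffusionTrainingModule()
config = module.parse_path_or_model_id("Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors")
print(config.model_id)             # Wan-AI/Wan2.1-T2V-1.3B
print(config.origin_file_pattern)  # diffusion_pytorch_model*.safetensors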
    def auto_detect_lora_target_modules(
        self,
        model: torch.nn.Module,
        search_for_linear=False,
        linear_detector=lambda x: min(x.weight.shape) >= 512,
        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
        name_prefix="",
    ):
        lora_target_modules = []
        if search_for_linear:
            for name, module in model.named_modules():
                module_name = name_prefix + ["", "."][name_prefix != ""] + name
                if isinstance(module, torch.nn.Linear) and linear_detector(module):
                    lora_target_modules.append(module_name)
        else:
            for name, module in model.named_children():
                module_name = name_prefix + ["", "."][name_prefix != ""] + name
                lora_target_modules += self.auto_detect_lora_target_modules(
                    module,
                    search_for_linear=block_list_detector(module),
                    linear_detector=linear_detector,
                    block_list_detector=block_list_detector,
                    name_prefix=module_name,
                )
        return lora_target_modules

    def parse_lora_target_modules(self, model, lora_target_modules):
        if lora_target_modules == "":
            print("No LoRA target modules specified. The framework will automatically search for them.")
            lora_target_modules = self.auto_detect_lora_target_modules(model)
            print(f"LoRA will be patched at {lora_target_modules}.")
        else:
            lora_target_modules = lora_target_modules.split(",")
        return lora_target_modules

    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        task="sft",
    ):
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Preset LoRA
        if preset_lora_path is not None:
            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)

        # FP8 relies on a model-specific memory management scheme,
        # so it is delegated to the subclass.

        # Add LoRA to the base models
        if lora_base_model is not None and not task.endswith(":data_process"):
            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
                print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
                return
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                if len(load_result[1]) > 0:
                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
            setattr(pipe, lora_base_model, model)

    def split_pipeline_units(
        self, task, pipe,
        trainable_models=None, lora_base_model=None,
        # TODO: set `remove_unnecessary_params` to `True` by default
        remove_unnecessary_params=False,
        # TODO: move `loss_required_params` to `loss.py`
        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
        force_remove_params_shared=tuple(),
        force_remove_params_posi=tuple(),
        force_remove_params_nega=tuple(),
    ):
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
            if remove_unnecessary_params:
                required_params = list(loss_required_params) + [i for i in inspect.signature(pipe.model_fn).parameters]
                for unit in other_units:
                    required_params.extend(unit.fetch_input_params())
                required_params = sorted(list(set(required_params)))
                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        controlnet_keys_map = (
            ("blockwise_controlnet_", "blockwise_controlnet_inputs"),
            ("controlnet_", "controlnet_inputs"),
        )
        controlnet_inputs = {}
        for extra_input in extra_inputs:
            for prefix, name in controlnet_keys_map:
                if extra_input.startswith(prefix):
                    if name not in controlnet_inputs:
                        controlnet_inputs[name] = {}
                    controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
                    break
            else:
                inputs_shared[extra_input] = data[extra_input]
        for name, params in controlnet_inputs.items():
            inputs_shared[name] = [ControlNetInput(**params)]
        return inputs_shared
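`parse_extra_inputs` strips the recognized `controlnet_` / `blockwise_controlnet_` prefixes, packs the remainder into a `ControlNetInput`, and passes everything else through unchanged. A hedged illustration; the `image` field on `ControlNetInput` is an assumption here, see `utils/controlnet.py` for the real fields:

import torch

module = DiffusionTrainingModule()
data = {"controlnet_image": torch.zeros(3, 64, 64), "fps": 16}
inputs_shared = module.parse_extra_inputs(data, ["controlnet_image", "fps"], {})
# -> {"fps": 16, "controlnet_inputs": [ControlNetInput(image=<tensor>)]}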
diffsynth/models/anima_dit.py
ADDED
@@ -0,0 +1,1307 @@
# original code from: comfy/ldm/cosmos/predict2.py

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import logging
from typing import Callable, Optional, Tuple, List
import math
from torchvision import transforms
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward


class VideoPositionEmb(nn.Module):
    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        """
        Delegates embedding generation to the `generate_embeddings` function.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None):
        raise NotImplementedError


def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
        eps (float, optional): A small constant to ensure numerical stability during division.

    Returns:
        torch.Tensor: The normalized tensor.
    """
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)
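Because the vector norm is rescaled by `sqrt(norm.numel() / x.numel())`, `normalize` divides by the root-mean-square over the normalized dimensions (plus `eps`), i.e. it behaves like an RMSNorm without a learned gain. A quick numerical check:

import torch

x = torch.randn(2, 8, 16)
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
assert torch.allclose(normalize(x, dim=[-1]), x / rms, atol=1e-5)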
class LearnablePosEmbAxis(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): Currently only "crop" is supported. Ideally, extrapolation capacity would come from adjusting frequencies or other more advanced methods; these are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)


class VideoRopePosition3DEmb(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        enable_fps_modulation: bool = True,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w
        self.enable_fps_modulation = enable_fps_modulation

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
        dtype=None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            torch.Tensor: Rotary embedding tensor of shape (T*H*W, head_dim // 2, 2, 2),
            i.e. one 2x2 rotation matrix per position and frequency.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None or self.enable_fps_modulation is False:  # image case
            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
        else:
            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ],
            dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()


def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
    t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
    t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
    return t_out
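`apply_rotary_pos_emb` pairs feature i with feature i + head_dim/2 and rotates each pair by the 2x2 matrices returned by `generate_embeddings`. A shape-level sanity check; the `unsqueeze(1)` adds a broadcast dimension over the head axis, which is an assumption about how the surrounding model prepares `rope_emb`:

import torch

T, H, W, n_heads, head_dim = 2, 4, 4, 8, 128
rope = VideoRopePosition3DEmb(head_dim=head_dim, len_h=H, len_w=W, len_t=T)
freqs = rope.generate_embeddings(torch.Size((1, T, H, W, n_heads * head_dim)))
q = torch.randn(1, T * H * W, n_heads, head_dim)
q_rot = apply_rotary_pos_emb(q, freqs.unsqueeze(1))  # broadcast over heads (assumed)
assert q_rot.shape == q.shape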
| 211 |
+
# ---------------------- Feed Forward Network -----------------------
|
| 212 |
+
class GPT2FeedForward(nn.Module):
|
| 213 |
+
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
| 214 |
+
super().__init__()
|
| 215 |
+
self.activation = nn.GELU()
|
| 216 |
+
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
| 217 |
+
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
| 218 |
+
|
| 219 |
+
self._layer_id = None
|
| 220 |
+
self._dim = d_model
|
| 221 |
+
self._hidden_dim = d_ff
|
| 222 |
+
|
| 223 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 224 |
+
x = self.layer1(x)
|
| 225 |
+
|
| 226 |
+
x = self.activation(x)
|
| 227 |
+
x = self.layer2(x)
|
| 228 |
+
return x
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
| 232 |
+
"""Computes multi-head attention using PyTorch's native implementation.
|
| 233 |
+
|
| 234 |
+
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
| 235 |
+
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
|
| 236 |
+
attention, and rearranges the output back to the original format.
|
| 237 |
+
|
| 238 |
+
The input tensor names use the following dimension conventions:
|
| 239 |
+
|
| 240 |
+
- B: batch size
|
| 241 |
+
- S: sequence length
|
| 242 |
+
- H: number of attention heads
|
| 243 |
+
- D: head dimension
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
|
| 247 |
+
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
|
| 248 |
+
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
|
| 252 |
+
"""
|
| 253 |
+
in_q_shape = q_B_S_H_D.shape
|
| 254 |
+
in_k_shape = k_B_S_H_D.shape
|
| 255 |
+
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
| 256 |
+
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
| 257 |
+
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
| 258 |
+
return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)")
|
class Attention(nn.Module):
    """
    A flexible attention module supporting both self-attention and cross-attention mechanisms.

    This module implements a multi-head attention layer that can operate in either self-attention
    or cross-attention mode. The mode is determined by whether a context dimension is provided.
    The implementation uses scaled dot-product attention and supports optional bias terms and
    dropout regularization.

    Args:
        query_dim (int): The dimensionality of the query vectors.
        context_dim (int, optional): The dimensionality of the context (key/value) vectors.
            If None, the module operates in self-attention mode using query_dim. Default: None
        n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
        head_dim (int, optional): The dimension of each attention head. Default: 64
        dropout (float, optional): Dropout probability applied to the output. Default: 0.0

    Examples:
        >>> # Self-attention with 512 dimensions and 8 heads
        >>> self_attn = Attention(query_dim=512)
        >>> x = torch.randn(32, 16, 512)  # (batch_size, seq_len, dim)
        >>> out = self_attn(x)  # (32, 16, 512)

        >>> # Cross-attention
        >>> cross_attn = Attention(query_dim=512, context_dim=256)
        >>> query = torch.randn(32, 16, 512)
        >>> context = torch.randn(32, 8, 256)
        >>> out = cross_attn(query, context)  # (32, 16, 512)
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention

        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = head_dim * n_heads

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.v_norm = nn.Identity()

        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()

        self.attn_op = torch_attention_op

        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = self.q_proj(x)
        context = x if context is None else context
        k = self.k_proj(context)
        v = self.v_proj(context)
        q, k, v = map(
            lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
            (q, k, v),
        )

        def apply_norm_and_rotary_pos_emb(
            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = self.q_norm(q)
            k = self.k_norm(k)
            v = self.v_norm(v)
            if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
                q = apply_rotary_pos_emb(q, rope_emb)
                k = apply_rotary_pos_emb(k, rope_emb)
            return q, k, v

        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
        result = self.attn_op(q, k, v, transformer_options=transformer_options)  # [B, S, H, D]
        return self.output_dropout(self.output_proj(result))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = {},
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K], or None to use x as context (self-attention)
        """
        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        timesteps = timesteps_B_T.flatten().float()
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
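For reference (not part of the module): the same sinusoidal embedding as a standalone function (the name sincos_timestep_embedding is hypothetical), handy for checking the cos-then-sin channel ordering used above:

import math
import torch

def sincos_timestep_embedding(timesteps: torch.Tensor, num_channels: int) -> torch.Tensor:
    # timesteps: (N,) -> (N, num_channels); first half cosine, second half sine
    half_dim = num_channels // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    angles = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([angles.cos(), angles.sin()], dim=-1)

emb = sincos_timestep_embedding(torch.tensor([0.0, 250.0, 999.0]), 256)
assert emb.shape == (3, 256)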
class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
        else:
            self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            adaln_lora_B_T_3D = emb
            emb_B_T_D = sample
        else:
            adaln_lora_B_T_3D = None
            emb_B_T_D = emb

        return emb_B_T_D, adaln_lora_B_T_3D
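For reference (not part of the module): illustrative shapes for the AdaLN-LoRA branch above; passing operations=torch.nn works in this sketch because only Linear is used:

import torch
import torch.nn as nn

emb_layer = TimestepEmbedding(256, 256, use_adaln_lora=True, operations=nn)
sample = torch.randn(2, 1, 256)                 # (B, T, D) sinusoidal timestep embedding
emb_B_T_D, adaln_lora_B_T_3D = emb_layer(sample)
assert emb_B_T_D.shape == (2, 1, 256)           # the LoRA branch passes the input through
assert adaln_lora_B_T_3D.shape == (2, 1, 768)   # 3*D shift/scale/gate offsets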
class PatchEmbed(nn.Module):
    """
    PatchEmbed is a module for embedding patches from an input tensor. It folds non-overlapping
    spatio-temporal patches into the channel dimension and projects each patch with a linear layer
    (equivalent to a strided 3D or 2D convolution). This module can process inputs with temporal
    (video) and spatial (image) dimensions, making it suitable for video and image processing tasks.
    It supports dividing the input into patches and embedding each patch into a vector of size `out_channels`.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
            ),
        )
        self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert (
            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return x
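For reference (not part of the module): a shape check of the patchify rearrangement with hypothetical sizes:

import torch
from einops import rearrange

B, C, T, H, W = 1, 16, 8, 64, 64
pt, ps = 1, 2                                   # temporal / spatial patch sizes
x = torch.randn(B, C, T, H, W)

patches = rearrange(x, "b c (t r) (h m) (w n) -> b t h w (c r m n)", r=pt, m=ps, n=ps)
assert patches.shape == (B, T // pt, H // ps, W // ps, C * pt * ps * ps)  # (1, 8, 32, 32, 64)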
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        if use_adaln_lora:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
            )

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None
            shift_B_T_D, scale_B_T_D = (
                self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
            ).chunk(2, dim=-1)
        else:
            shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)

        shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
            scale_B_T_D, "b t d -> b t 1 1 d"
        )

        def _fn(
            _x_B_T_H_W_D: torch.Tensor,
            _norm_layer: nn.Module,
            _scale_B_T_1_1_D: torch.Tensor,
            _shift_B_T_1_1_D: torch.Tensor,
        ) -> torch.Tensor:
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
        x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
        return x_B_T_H_W_O
class Block(nn.Module):
    """
    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
    Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256

    The block applies the following sequence:
    1. Self-attention with AdaLN modulation
    2. Cross-attention with AdaLN modulation
    3. MLP with AdaLN modulation

    Each component uses skip connections and layer normalization.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
        )

        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)

        self.use_adaln_lora = use_adaln_lora
        if self.use_adaln_lora:
            self.adaln_modulation_self_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_cross_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_mlp = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = {},
    ) -> torch.Tensor:
        residual_dtype = x_B_T_H_W_D.dtype
        compute_dtype = emb_B_T_D.dtype
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

        if self.use_adaln_lora:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
                self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
        else:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)

        # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
        shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
        scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
        gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")

        B, T, H, W, D = x_B_T_H_W_D.shape

        def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_self_attn,
            scale_self_attn_B_T_1_1_D,
            shift_self_attn_B_T_1_1_D,
        )
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
                transformer_options=transformer_options,
            ),
            "b (t h w) d -> b t h w d",
            t=T,
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)

        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
            layer_norm_cross_attn: Callable,
            _scale_cross_attn_B_T_1_1_D: torch.Tensor,
            _shift_cross_attn_B_T_1_1_D: torch.Tensor,
            transformer_options: Optional[dict] = {},
        ) -> torch.Tensor:
            _normalized_x_B_T_H_W_D = _fn(
                _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    rope_emb=rope_emb_L_1_1_D,
                    transformer_options=transformer_options,
                ),
                "b (t h w) d -> b t h w d",
                t=T,
                h=H,
                w=W,
            )
            return _result_B_T_H_W_D

        result_B_T_H_W_D = _x_fn(
            x_B_T_H_W_D,
            self.layer_norm_cross_attn,
            scale_cross_attn_B_T_1_1_D,
            shift_cross_attn_B_T_1_1_D,
            transformer_options=transformer_options,
        )
        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_mlp,
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
        return x_B_T_H_W_D
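For reference (not part of the module): the shift/scale/gate pattern shared by Block and FinalLayer, reduced to a standalone sketch with hypothetical sizes:

import torch
import torch.nn as nn

B, T, H, W, D = 1, 4, 8, 8, 64
x = torch.randn(B, T, H, W, D)
emb = torch.randn(B, T, 3 * D)                   # stands in for the SiLU+Linear modulation output

shift, scale, gate = emb.chunk(3, dim=-1)        # each (B, T, D)
shift, scale, gate = (t[:, :, None, None, :] for t in (shift, scale, gate))

norm = nn.LayerNorm(D, elementwise_affine=False)
modulated = norm(x) * (1 + scale) + shift        # AdaLN: per-timestep affine on normalized features
x = x + gate * modulated                         # gated residual update (attention/MLP would sit here)
assert x.shape == (B, T, H, W, D)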
class MiniTrainDIT(nn.Module):
    """
    A clean implementation of DiT that can load and reproduce the training results of the original
    DiT model (Cosmos 1): a general AdaLN-modulated, ViT-like (DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (int): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        min_fps (int): Minimum frames per second.
        max_fps (int): Maximum frames per second.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int,  # tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        min_fps: int = 1,
        max_fps: int = 30,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        rope_enable_fps_modulation: bool = True,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.rope_enable_fps_modulation = rope_enable_fps_modulation

        self.build_pos_embed(device=device, dtype=dtype)
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        self.t_embedder = nn.Sequential(
            Timesteps(model_channels),
            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
        )

        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            device=device, dtype=dtype, operations=operations,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    x_dim=model_channels,
                    context_dim=crossattn_emb_channels,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    device=device, dtype=dtype, operations=operations,
                )
                for _ in range(num_blocks)
            ]
        )

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            device=device, dtype=dtype, operations=operations,
        )

        self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)

    def build_pos_embed(self, device=None, dtype=None) -> None:
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            max_fps=self.max_fps,
            min_fps=self.min_fps,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            enable_fps_modulation=self.rope_enable_fps_modulation,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,  # type: ignore
        )

        if self.extra_per_block_abs_pos_emb:
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,  # type: ignore
            )

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): currently it is not used

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                  (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
                - An optional extra per-block absolute positional embedding tensor, or None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
              the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
              `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            if padding_mask is None:
                padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
            else:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
        x_B_C_Tt_Hp_Wp = rearrange(
            x_B_T_H_W_M,
            "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            t=self.patch_temporal,
        )
        return x_B_C_Tt_Hp_Wp

    def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"):
        if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
            padding_mode = "reflect"

        pad = ()
        for i in range(img.ndim - 2):
            pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad

        return torch.nn.functional.pad(img, pad, mode=padding_mode)

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        **kwargs,
    ):
        """
        Args:
            x: (B, C, T, H, W) tensor of spatio-temporal inputs
            timesteps: (B, ) tensor of timesteps
            context: (B, N, D) tensor of cross-attention embeddings
        """
        orig_shape = list(x.shape)
        x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
        x_B_C_T_H_W = x
        timesteps_B_T = timesteps
        crossattn_emb = context
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x_B_C_T_H_W,
            fps=fps,
            padding_mask=padding_mask,
        )

        if timesteps_B_T.ndim == 1:
            timesteps_B_T = timesteps_B_T.unsqueeze(1)
        t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
        t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)

        # for logging purposes
        affline_scale_log_info = {}
        affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
        self.affline_scale_log_info = affline_scale_log_info
        self.affline_emb = t_embedding_B_T_D
        self.crossattn_emb = crossattn_emb

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"

        block_kwargs = {
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
            "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
            "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
            "transformer_options": kwargs.get("transformer_options", {}),
        }

        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
        # in fp32, but run attention and MLP modules in fp16.
        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
        # quality degradation and visual artifacts.
        if x_B_T_H_W_D.dtype == torch.float16:
            x_B_T_H_W_D = x_B_T_H_W_D.float()

        for block in self.blocks:
            x_B_T_H_W_D = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x_B_T_H_W_D=x_B_T_H_W_D,
                emb_B_T_D=t_embedding_B_T_D,
                crossattn_emb=crossattn_emb,
                **block_kwargs,
            )

        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
        return x_B_C_Tt_Hp_Wp
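For reference (not part of the module): pad_to_patch_size assembles the pad tuple back-to-front because torch.nn.functional.pad consumes (last-dim, second-to-last, ...) pairs; a standalone check of that ordering:

import torch
import torch.nn.functional as F

img = torch.randn(1, 16, 7, 61, 62)             # (B, C, T, H, W)
patch = (1, 2, 2)                               # (temporal, spatial, spatial)

pad = ()
for i in range(img.ndim - 2):                   # walks dims T, H, W, prepending each pair
    pad = (0, (patch[i] - img.shape[i + 2] % patch[i]) % patch[i]) + pad
out = F.pad(img, pad, mode="circular")
assert out.shape == (1, 16, 7, 62, 62)          # T unchanged (patch 1); H rounded up; W already even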
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.rope_theta = 10000
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
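For reference (not part of the module): generating and applying the rotary embedding above, with shapes matching LLMAdapter's usage (apply_rotary_pos_emb2 is the helper defined earlier in this file):

import torch

head_dim, seq = 64, 10
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
pos = torch.arange(seq).float()
freqs = torch.outer(pos, inv_freq)              # (seq, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)         # duplicated halves, matching RotaryEmbedding
cos, sin = emb.cos()[None], emb.sin()[None]     # (1, seq, head_dim)

q = torch.randn(1, 8, seq, head_dim)            # (B, n_heads, seq, head_dim)
q_rot = apply_rotary_pos_emb2(q, cos, sin)      # unsqueeze_dim=1 broadcasts over heads
assert q_rot.shape == q.shape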
class LLMAdapterAttention(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
        super().__init__()

        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)

        self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
        context = x if context is None else context
        input_shape = x.shape[:-1]
        q_shape = (*input_shape, self.n_heads, self.head_dim)
        context_shape = context.shape[:-1]
        kv_shape = (*context_shape, self.n_heads, self.head_dim)

        query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
        value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)

        if position_embeddings is not None:
            assert position_embeddings_context is not None
            cos, sin = position_embeddings
            query_states = apply_rotary_pos_emb2(query_states, cos, sin)
            cos, sin = position_embeddings_context
            key_states = apply_rotary_pos_emb2(key_states, cos, sin)

        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)

        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output

    def init_weights(self):
        torch.nn.init.zeros_(self.o_proj.weight)


class LLMAdapterTransformerBlock(nn.Module):
    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.use_self_attn = use_self_attn

        if self.use_self_attn:
            self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
            self.self_attn = LLMAdapterAttention(
                query_dim=model_dim,
                context_dim=model_dim,
                n_heads=num_heads,
                head_dim=model_dim//num_heads,
                device=device,
                dtype=dtype,
                operations=operations,
            )

        self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = LLMAdapterAttention(
            query_dim=model_dim,
            context_dim=source_dim,
            n_heads=num_heads,
            head_dim=model_dim//num_heads,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
            nn.GELU(),
            operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
        )

    def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
        if self.use_self_attn:
            normed = self.norm_self_attn(x)
            attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
            x = x + attn_out

        normed = self.norm_cross_attn(x)
        attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
        x = x + attn_out

        x = x + self.mlp(self.norm_mlp(x))
        return x

    def init_weights(self):
        torch.nn.init.zeros_(self.mlp[2].weight)
        self.cross_attn.init_weights()


class LLMAdapter(nn.Module):
    def __init__(
        self,
        source_dim=1024,
        target_dim=1024,
        model_dim=1024,
        num_layers=6,
        num_heads=16,
        use_self_attn=True,
        layer_norm=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
        if model_dim != target_dim:
            self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
        else:
            self.in_proj = nn.Identity()
        self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
        self.blocks = nn.ModuleList([
            LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
        ])
        self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
        self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)

    def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
        if target_attention_mask is not None:
            target_attention_mask = target_attention_mask.to(torch.bool)
            if target_attention_mask.ndim == 2:
                target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)

        if source_attention_mask is not None:
            source_attention_mask = source_attention_mask.to(torch.bool)
            if source_attention_mask.ndim == 2:
                source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

        context = source_hidden_states
        x = self.in_proj(self.embed(target_input_ids).to(context.dtype))
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
        position_embeddings_context = self.rotary_emb(x, position_ids_context)
        for block in self.blocks:
            x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
        return self.norm(self.out_proj(x))
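For reference (not part of the module): LLMAdapter expands 2-D key-padding masks to 4-D so they broadcast as scaled_dot_product_attention's boolean attn_mask (True means "attend"); a standalone sketch of that broadcast:

import torch
import torch.nn.functional as F

B, heads, Lq, Lk, d = 2, 4, 5, 7, 16
q = torch.randn(B, heads, Lq, d)
k = torch.randn(B, heads, Lk, d)
v = torch.randn(B, heads, Lk, d)

key_padding = torch.tensor([[1, 1, 1, 1, 1, 0, 0],
                            [1, 1, 1, 0, 0, 0, 0]]).bool()  # (B, Lk), True = real token
mask = key_padding[:, None, None, :]                        # (B, 1, 1, Lk): broadcasts over heads and queries
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert out.shape == (B, heads, Lq, d)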
class AnimaDiT(MiniTrainDIT):

    _repeated_blocks = ["Block"]

    def __init__(self):
        kwargs = {
            'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128,
            'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1,
            'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024,
            'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop',
            'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256,
            'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False,
            'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0,
            'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0,
            'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0,
            'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None,
            'operations': torch.nn,
        }
        super().__init__(**kwargs)
        self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
        if text_ids is not None:
            out = self.llm_adapter(text_embeds, text_ids)
            if t5xxl_weights is not None:
                out = out * t5xxl_weights

            if out.shape[1] < 512:
                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
            return out
        else:
            return text_embeds

    def forward(
        self,
        x, timesteps, context,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        **kwargs
    ):
        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
        if t5xxl_ids is not None:
            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
        return super().forward(
            x, timesteps, context,
            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
            **kwargs
        )
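For reference (not part of the module): preprocess_text_embeds pads the adapter output along the sequence axis to a fixed length of 512; with torch.nn.functional.pad, the leading (0, 0) leaves the channel dimension untouched:

import torch
import torch.nn.functional as F

out = torch.randn(2, 300, 1024)                     # (B, seq, target_dim) from the adapter
padded = F.pad(out, (0, 0, 0, 512 - out.shape[1]))  # pads dim -2 on the right only
assert padded.shape == (2, 512, 1024)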
diffsynth/models/dinov3_image_encoder.py
ADDED
@@ -0,0 +1,96 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch

from ..core.device.npu_compatible_device import get_device_type


class DINOv3ImageEncoder(DINOv3ViTModel):
    def __init__(self):
        config = DINOv3ViTConfig(
            architectures = [
                "DINOv3ViTModel"
            ],
            attention_dropout = 0.0,
            drop_path_rate = 0.0,
            dtype = "float32",
            hidden_act = "silu",
            hidden_size = 4096,
            image_size = 224,
            initializer_range = 0.02,
            intermediate_size = 8192,
            key_bias = False,
            layer_norm_eps = 1e-05,
            layerscale_value = 1.0,
            mlp_bias = True,
            model_type = "dinov3_vit",
            num_attention_heads = 32,
            num_channels = 3,
            num_hidden_layers = 40,
            num_register_tokens = 4,
            patch_size = 16,
            pos_embed_jitter = None,
            pos_embed_rescale = 2.0,
            pos_embed_shift = None,
            proj_bias = True,
            query_bias = False,
            rope_theta = 100.0,
            transformers_version = "4.56.1",
            use_gated_mlp = True,
            value_bias = False
        )
        super().__init__(config)
        self.processor = DINOv3ViTImageProcessorFast(
            crop_size = None,
            data_format = "channels_first",
            default_to_square = True,
            device = None,
            disable_grouping = None,
            do_center_crop = None,
            do_convert_rgb = None,
            do_normalize = True,
            do_rescale = True,
            do_resize = True,
            image_mean = [
                0.485,
                0.456,
                0.406
            ],
            image_processor_type = "DINOv3ViTImageProcessorFast",
            image_std = [
                0.229,
                0.224,
                0.225
            ],
            input_data_format = None,
            resample = 2,
            rescale_factor = 0.00392156862745098,
            return_tensors = None,
            size = {
                "height": 224,
                "width": 224
            }
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
        bool_masked_pos = None
        head_mask = None

        pixel_values = pixel_values.to(torch_dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return pooled_output
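For reference (not part of the module): the processor above amounts to standard ImageNet preprocessing (resample=2 is PIL bilinear; rescale_factor is 1/255); a torchvision equivalent as a sketch:

import torch
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),  # resample=2
    transforms.ToTensor(),                                   # applies the 1/255 rescale
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
pixel_values = preprocess(transforms.ToPILImage()(torch.rand(3, 480, 640))).unsqueeze(0)
assert pixel_values.shape == (1, 3, 224, 224)                # matches the encoder's expected input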
diffsynth/models/flux2_dit.py
ADDED
|
@@ -0,0 +1,1053 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import torch, math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            The dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.
    Returns:
        torch.Tensor: an [N x dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
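As a quick sanity check of the embedding shapes, a short sketch under the defaults above (not part of the diff):

```python
# Sketch: sinusoidal embeddings for a batch of 4 timesteps at dim 256.
import torch

t = torch.tensor([0.0, 250.0, 500.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([4, 256]) -- first 128 dims cos, last 128 sin after the flip
```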
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = torch.nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb
class AdaLayerNormContinuous(nn.Module):
    r"""
    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. Shape: [B, H, S, D] (or [B, S, H, D] when
            `sequence_dim=1`).
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
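A small sketch tying the two RoPE helpers together (illustrative shapes; not part of the diff):

```python
# Sketch: build cos/sin tables for 10 positions at head dim 64, then rotate queries.
import torch

cos_sin = get_1d_rotary_pos_embed(dim=64, pos=10, use_real=True, repeat_interleave_real=True)
q = torch.randn(2, 10, 8, 64)  # [B, S, H, D], matching sequence_dim=1 below
q_rot = apply_rotary_emb(q, cos_sin, sequence_dim=1)
print(q_rot.shape)  # torch.Size([2, 10, 8, 64]) -- shape is preserved, phases are rotated
```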
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    encoder_query = encoder_key = encoder_value = (None,)
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)
class Flux2SwiGLU(nn.Module):
    """
    Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
    layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        x = self.gate_fn(x1) * x2
        return x


class Flux2FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: float = 3.0,
        inner_dim: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out or dim

        # Flux2SwiGLU will reduce the dimension by half
        self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.act_fn = Flux2SwiGLU()
        self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)
        x = self.act_fn(x)
        x = self.linear_out(x)
        return x
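To make the fused-SwiGLU dimension flow concrete, a sketch with illustrative sizes (not part of the diff):

```python
# Sketch: linear_in projects dim -> 2*inner_dim; the SwiGLU chunk halves it back
# to inner_dim, so linear_out maps inner_dim -> dim_out with no separate gate matrix.
import torch

ff = Flux2FeedForward(dim=128, mult=3.0, bias=False)
x = torch.randn(2, 16, 128)          # [B, S, dim]
print(ff.linear_in(x).shape)         # torch.Size([2, 16, 768]) == 2 * inner_dim (384)
print(ff(x).shape)                   # torch.Size([2, 16, 128])
```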
class Flux2AttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class Flux2Attention(torch.nn.Module):
    _default_processor_cls = Flux2AttnProcessor
    _available_processors = [Flux2AttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
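A minimal joint-attention smoke test for the block above (a sketch; shapes are illustrative, and `attention_forward` must be available from the bundled `diffsynth.core.attention`):

```python
# Sketch: double-stream attention over 16 image tokens and 4 text tokens.
import torch

attn = Flux2Attention(query_dim=512, heads=8, dim_head=64, added_kv_proj_dim=512)
img = torch.randn(1, 16, 512)   # image stream
txt = torch.randn(1, 4, 512)    # text stream
img_out, txt_out = attn(hidden_states=img, encoder_hidden_states=txt)
print(img_out.shape, txt_out.shape)  # torch.Size([1, 16, 512]) torch.Size([1, 4, 512])
```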
class Flux2ParallelSelfAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Parallel in (QKV + MLP in) projection
        hidden_states = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
        )

        # Handle the attention logic
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
        hidden_states = attn.to_out(hidden_states)

        return hidden_states
class Flux2ParallelSelfAttention(torch.nn.Module):
    """
    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.

    This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
    input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
    """

    _default_processor_cls = Flux2ParallelSelfAttnProcessor
    _available_processors = [Flux2ParallelSelfAttnProcessor]
    # Does not support QKV fusion as the QKV projections are always fused
    _supports_qkv_fusion = False

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        mlp_ratio: float = 4.0,
        mlp_mult_factor: int = 2,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
        self.mlp_mult_factor = mlp_mult_factor

        # Fused QKV projections + MLP input projection
        self.to_qkv_mlp_proj = torch.nn.Linear(
            self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
        )
        self.mlp_act_fn = Flux2SwiGLU()

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        # Fused attention output projection + MLP output projection
        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
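The fused projection widths can be read off directly from the constructor arithmetic (a sketch with illustrative sizes):

```python
# Sketch: for query_dim=512, heads=8, dim_head=64, mlp_ratio=4.0, mlp_mult_factor=2,
# the single fused input projection emits 3*512 (QKV) + 2048*2 (SwiGLU input) features.
import torch

attn = Flux2ParallelSelfAttention(query_dim=512, heads=8, dim_head=64)
print(attn.to_qkv_mlp_proj.out_features)  # 3*512 + 2048*2 = 5632
x = torch.randn(1, 16, 512)
print(attn(x).shape)                      # torch.Size([1, 16, 512])
```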
class Flux2SingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
        # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
        # for a visual depiction of this type of transformer block.
        self.attn = Flux2ParallelSelfAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            out_bias=bias,
            eps=eps,
            mlp_ratio=mlp_ratio,
            mlp_mult_factor=2,
            processor=Flux2ParallelSelfAttnProcessor(),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        split_hidden_states: bool = False,
        text_seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
        # concatenated
        if encoder_hidden_states is not None:
            text_seq_len = encoder_hidden_states.shape[1]
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        mod_shift, mod_scale, mod_gate = temb_mod_params

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = hidden_states + mod_gate * attn_output
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        if split_hidden_states:
            encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
            return encoder_hidden_states, hidden_states
        else:
            return hidden_states
class Flux2TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.attn = Flux2Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            added_proj_bias=bias,
            out_bias=bias,
            eps=eps,
            processor=Flux2AttnProcessor(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Modulation parameters shape: [1, 1, self.dim]
        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt

        # Img stream
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa

        # Conditioning txt stream
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa

        # Attention on concatenated img + txt stream
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        attn_output, context_attn_output = attention_outputs

        # Process attention outputs for the image stream (`hidden_states`).
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_mlp * ff_output

        # Process attention outputs for the text stream (`encoder_hidden_states`).
        context_attn_output = c_gate_msa * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
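Both block types condition their sub-layers with the same adaptive-LayerNorm pattern: a shift/scale pair modulates the normalized input and a gate scales the residual update. A sketch of the arithmetic (not part of the diff):

```python
# Sketch of the modulation arithmetic used in both block types:
#   y = norm(x) * (1 + scale) + shift      -- conditions the sub-block input
#   x = x + gate * sublayer(y)             -- gates the residual update
import torch

x = torch.randn(1, 16, 512)
shift, scale, gate = (torch.randn(1, 1, 512) for _ in range(3))
norm = torch.nn.LayerNorm(512, elementwise_affine=False)
y = (1 + scale) * norm(x) + shift
x = x + gate * y  # stand-in for gate * attn(y) or gate * ff(y)
```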
class Flux2PosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Expected ids shape: [S, len(self.axes_dim)]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
        for i in range(len(self.axes_dim)):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[..., i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
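Each position-id row carries one coordinate per RoPE axis, and each axis gets its own frequency table; the concatenated tables cover the full attention head. A sketch with the model's default `axes_dim` (not part of the diff):

```python
# Sketch: 4-axis RoPE ids for 16 tokens; each of the 4 columns feeds one 32-dim axis,
# so the concatenated cos/sin tables cover the full 128-dim attention head.
import torch

pos_embed = Flux2PosEmbed(theta=2000, axes_dim=[32, 32, 32, 32])
ids = torch.zeros(16, 4)                 # [S, len(axes_dim)] coordinate rows
freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([16, 128]) torch.Size([16, 128])
```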
class Flux2TimestepGuidanceEmbeddings(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        embedding_dim: int = 6144,
        bias: bool = False,
        guidance_embeds: bool = True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
        )

        if guidance_embeds:
            self.guidance_embedder = TimestepEmbedding(
                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
            )
        else:
            self.guidance_embedder = None

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)

        if guidance is not None and self.guidance_embedder is not None:
            guidance_proj = self.time_proj(guidance)
            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
            time_guidance_emb = timesteps_emb + guidance_emb
            return time_guidance_emb
        else:
            return timesteps_emb
class Flux2Modulation(nn.Module):
    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets

        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        mod = self.act_fn(temb)
        mod = self.linear(mod)

        if mod.ndim == 2:
            mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        # Return tuple of 3-tuples of modulation params shift/scale/gate
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
| 881 |
+
class Flux2DiT(torch.nn.Module):
|
| 882 |
+
|
| 883 |
+
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
| 884 |
+
|
| 885 |
+
def __init__(
|
| 886 |
+
self,
|
| 887 |
+
patch_size: int = 1,
|
| 888 |
+
in_channels: int = 128,
|
| 889 |
+
out_channels: Optional[int] = None,
|
| 890 |
+
num_layers: int = 8,
|
| 891 |
+
num_single_layers: int = 48,
|
| 892 |
+
attention_head_dim: int = 128,
|
| 893 |
+
num_attention_heads: int = 48,
|
| 894 |
+
joint_attention_dim: int = 15360,
|
| 895 |
+
timestep_guidance_channels: int = 256,
|
| 896 |
+
mlp_ratio: float = 3.0,
|
| 897 |
+
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
| 898 |
+
rope_theta: int = 2000,
|
| 899 |
+
eps: float = 1e-6,
|
| 900 |
+
guidance_embeds: bool = True,
|
| 901 |
+
):
|
| 902 |
+
super().__init__()
|
| 903 |
+
self.out_channels = out_channels or in_channels
|
| 904 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 905 |
+
|
| 906 |
+
# 1. Sinusoidal positional embedding for RoPE on image and text tokens
|
| 907 |
+
self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
|
| 908 |
+
|
| 909 |
+
# 2. Combined timestep + guidance embedding
|
| 910 |
+
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
| 911 |
+
in_channels=timestep_guidance_channels,
|
| 912 |
+
embedding_dim=self.inner_dim,
|
| 913 |
+
bias=False,
|
| 914 |
+
guidance_embeds=guidance_embeds,
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
| 918 |
+
# Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
|
| 919 |
+
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 920 |
+
self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 921 |
+
# Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
|
| 922 |
+
self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
|
| 923 |
+
|
| 924 |
+
# 4. Input projections
|
| 925 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
|
| 926 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
|
| 927 |
+
|
| 928 |
+
# 5. Double Stream Transformer Blocks
|
| 929 |
+
self.transformer_blocks = nn.ModuleList(
|
| 930 |
+
[
|
| 931 |
+
Flux2TransformerBlock(
|
| 932 |
+
dim=self.inner_dim,
|
| 933 |
+
num_attention_heads=num_attention_heads,
|
| 934 |
+
attention_head_dim=attention_head_dim,
|
| 935 |
+
mlp_ratio=mlp_ratio,
|
| 936 |
+
eps=eps,
|
| 937 |
+
bias=False,
|
| 938 |
+
)
|
| 939 |
+
for _ in range(num_layers)
|
| 940 |
+
]
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
# 6. Single Stream Transformer Blocks
|
| 944 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 945 |
+
[
|
| 946 |
+
Flux2SingleTransformerBlock(
|
| 947 |
+
dim=self.inner_dim,
|
| 948 |
+
num_attention_heads=num_attention_heads,
|
| 949 |
+
attention_head_dim=attention_head_dim,
|
| 950 |
+
mlp_ratio=mlp_ratio,
|
| 951 |
+
eps=eps,
|
| 952 |
+
bias=False,
|
| 953 |
+
)
|
| 954 |
+
for _ in range(num_single_layers)
|
| 955 |
+
]
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
# 7. Output layers
|
| 959 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 960 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
|
| 961 |
+
)
|
| 962 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
| 963 |
+
|
| 964 |
+
self.gradient_checkpointing = False
|
| 965 |
+
|
| 966 |
+
def forward(
|
| 967 |
+
self,
|
| 968 |
+
hidden_states: torch.Tensor,
|
| 969 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 970 |
+
timestep: torch.LongTensor = None,
|
| 971 |
+
img_ids: torch.Tensor = None,
|
| 972 |
+
txt_ids: torch.Tensor = None,
|
| 973 |
+
guidance: torch.Tensor = None,
|
| 974 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 975 |
+
use_gradient_checkpointing=False,
|
| 976 |
+
use_gradient_checkpointing_offload=False,
|
| 977 |
+
):
|
| 978 |
+
# 0. Handle input arguments
|
| 979 |
+
if joint_attention_kwargs is not None:
|
| 980 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 981 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 982 |
+
else:
|
| 983 |
+
lora_scale = 1.0
|
| 984 |
+
|
| 985 |
+
num_txt_tokens = encoder_hidden_states.shape[1]
|
| 986 |
+
|
| 987 |
+
# 1. Calculate timestep embedding and modulation parameters
|
| 988 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 989 |
+
|
| 990 |
+
if guidance is not None:
|
| 991 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 992 |
+
|
| 993 |
+
temb = self.time_guidance_embed(timestep, guidance)
|
| 994 |
+
|
| 995 |
+
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
| 996 |
+
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
| 997 |
+
single_stream_mod = self.single_stream_modulation(temb)[0]
|
| 998 |
+
|
| 999 |
+
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
| 1000 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 1001 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1002 |
+
|
| 1003 |
+
# 3. Calculate RoPE embeddings from image and text tokens
|
| 1004 |
+
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
| 1005 |
+
# text prompts of differents lengths. Is this a use case we want to support?
|
| 1006 |
+
if img_ids.ndim == 3:
|
| 1007 |
+
img_ids = img_ids[0]
|
| 1008 |
+
if txt_ids.ndim == 3:
|
| 1009 |
+
txt_ids = txt_ids[0]
|
| 1010 |
+
|
| 1011 |
+
image_rotary_emb = self.pos_embed(img_ids)
|
| 1012 |
+
text_rotary_emb = self.pos_embed(txt_ids)
|
| 1013 |
+
concat_rotary_emb = (
|
| 1014 |
+
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
| 1015 |
+
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
# 4. Double Stream Transformer Blocks
|
| 1019 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 1020 |
+
encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
|
| 1021 |
+
block,
|
| 1022 |
+
use_gradient_checkpointing=use_gradient_checkpointing,
|
| 1023 |
+
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
| 1024 |
+
hidden_states=hidden_states,
|
| 1025 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1026 |
+
temb_mod_params_img=double_stream_mod_img,
|
| 1027 |
+
temb_mod_params_txt=double_stream_mod_txt,
|
| 1028 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1029 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1030 |
+
)
|
| 1031 |
+
# Concatenate text and image streams for single-block inference
|
| 1032 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 1033 |
+
|
| 1034 |
+
# 5. Single Stream Transformer Blocks
|
| 1035 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 1036 |
+
hidden_states = gradient_checkpoint_forward(
|
| 1037 |
+
block,
|
| 1038 |
+
use_gradient_checkpointing=use_gradient_checkpointing,
|
| 1039 |
+
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
| 1040 |
+
hidden_states=hidden_states,
|
| 1041 |
+
encoder_hidden_states=None,
|
| 1042 |
+
temb_mod_params=single_stream_mod,
|
| 1043 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1044 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1045 |
+
)
|
| 1046 |
+
# Remove text tokens from concatenated stream
|
| 1047 |
+
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
| 1048 |
+
|
| 1049 |
+
# 6. Output layers
|
| 1050 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1051 |
+
output = self.proj_out(hidden_states)
|
| 1052 |
+
|
| 1053 |
+
return output
|
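End-to-end, the forward pass above can be smoke-tested with a deliberately tiny configuration (a sketch; all sizes are scaled down for illustration and do not match the shipped checkpoint, and it assumes `gradient_checkpoint_forward` simply calls the block when checkpointing is disabled):

```python
# Sketch: tiny Flux2DiT forward. Note axes_dims_rope must sum to attention_head_dim.
import torch

model = Flux2DiT(
    in_channels=16, num_layers=1, num_single_layers=1,
    attention_head_dim=32, num_attention_heads=4,
    joint_attention_dim=64, axes_dims_rope=(8, 8, 8, 8),
)
img = torch.randn(1, 64, 16)            # [B, image tokens, in_channels]
txt = torch.randn(1, 8, 64)             # [B, text tokens, joint_attention_dim]
out = model(
    hidden_states=img, encoder_hidden_states=txt,
    timestep=torch.tensor([0.5]), guidance=torch.tensor([4.0]),
    img_ids=torch.zeros(64, 4), txt_ids=torch.zeros(8, 4),
)
print(out.shape)                        # torch.Size([1, 64, 16])
```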
diffsynth/models/flux2_text_encoder.py ADDED
@@ -0,0 +1,58 @@
from transformers import Mistral3ForConditionalGeneration, Mistral3Config


class Flux2TextEncoder(Mistral3ForConditionalGeneration):
    def __init__(self):
        config = Mistral3Config(**{
            "architectures": ["Mistral3ForConditionalGeneration"],
            "dtype": "bfloat16",
            "image_token_index": 10,
            "model_type": "mistral3",
            "multimodal_projector_bias": False,
            "projector_hidden_act": "gelu",
            "spatial_merge_size": 2,
            "text_config": {
                "attention_dropout": 0.0,
                "dtype": "bfloat16",
                "head_dim": 128,
                "hidden_act": "silu",
                "hidden_size": 5120,
                "initializer_range": 0.02,
                "intermediate_size": 32768,
                "max_position_embeddings": 131072,
                "model_type": "mistral",
                "num_attention_heads": 32,
                "num_hidden_layers": 40,
                "num_key_value_heads": 8,
                "rms_norm_eps": 1e-05,
                "rope_theta": 1000000000.0,
                "sliding_window": None,
                "use_cache": True,
                "vocab_size": 131072,
            },
            "transformers_version": "4.57.1",
            "vision_config": {
                "attention_dropout": 0.0,
                "dtype": "bfloat16",
                "head_dim": 64,
                "hidden_act": "silu",
                "hidden_size": 1024,
                "image_size": 1540,
                "initializer_range": 0.02,
                "intermediate_size": 4096,
                "model_type": "pixtral",
                "num_attention_heads": 16,
                "num_channels": 3,
                "num_hidden_layers": 24,
                "patch_size": 14,
                "rope_theta": 10000.0,
            },
            "vision_feature_layer": -1,
        })
        super().__init__(config)

    def forward(
        self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None,
        past_key_values=None, inputs_embeds=None, labels=None, use_cache=None,
        output_attentions=None, output_hidden_states=None, return_dict=None,
        cache_position=None, logits_to_keep=0, image_sizes=None, **kwargs,
    ):
        return super().forward(
            input_ids, pixel_values, attention_mask, position_ids, past_key_values,
            inputs_embeds, labels, use_cache, output_attentions, output_hidden_states,
            return_dict, cache_position, logits_to_keep, image_sizes, **kwargs,
        )
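The hard-coded config pins the text encoder to a fixed Mistral-3 layout (a 5120-dim, 40-layer text stack plus a Pixtral vision tower), so the class can be constructed without any config file. A minimal instantiation sketch (the weight file name and loading step are assumptions, not shown in the diff):

```python
# Hypothetical sketch: build the encoder from its pinned config, then load
# bundled weights (path and loader are assumptions for illustration).
import torch

encoder = Flux2TextEncoder()
state_dict = torch.load("flux2_text_encoder.pth", map_location="cpu")  # assumed path
encoder.load_state_dict(state_dict)
encoder = encoder.to(dtype=torch.bfloat16).eval()
```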
diffsynth/models/flux2_vae.py ADDED
The diff for this file is too large to render. See raw diff.
diffsynth/models/flux_controlnet.py ADDED
@@ -0,0 +1,384 @@
import hashlib  # needed by hash_state_dict_keys (missing from the original diff)
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager

def hash_state_dict_keys(state_dict, with_shape=True):
    # Note: `convert_state_dict_keys_to_single_str` comes from the original
    # diffsynth utils module and must be in scope for this helper to run.
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)

class FluxControlNet(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])

        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])

        self.mode_dict = mode_dict
        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)

    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
        if len(res_stack) == 0:
            return [torch.zeros_like(hidden_states)] * num_blocks
        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
        return aligned_res_stack

    def forward(
        self,
        hidden_states,
        controlnet_conditioning,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        processor_id=None,
        tiled=False, tile_size=128, tile_stride=64,
        **kwargs
    ):
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)
        if self.controlnet_mode_embedder is not None:  # Different from FluxDiT
            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
            processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        controlnet_conditioning = self.patchify(controlnet_conditioning)  # Different from FluxDiT
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning)  # Different from FluxDiT

        controlnet_res_stack = []
        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_res_stack.append(controlnet_block(hidden_states))

        controlnet_single_res_stack = []
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))

        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])

        return controlnet_res_stack, controlnet_single_res_stack

    # @staticmethod
    # def state_dict_converter():
    #     return FluxControlNetStateDictConverter()

    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class QLinear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)

            class QRMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

            class QEmbedding(torch.nn.Embedding):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight = cast_weight(self, input)
                    return torch.nn.functional.embedding(
                        input, weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, quantized_layer.QRMSNorm):
                    continue
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.QLinear(module.in_features, module.out_features)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    if hasattr(module, "quantized"):
                        continue
                    module.quantized = True
                    new_layer = quantized_layer.QRMSNorm(module)
                    setattr(model, name, new_layer)
                elif isinstance(module, torch.nn.Embedding):
                    rows, cols = module.weight.shape
                    new_layer = quantized_layer.QEmbedding(
                        num_embeddings=rows,
                        embedding_dim=cols,
                        _weight=module.weight,
                        # _freeze=module.freeze,
                        padding_idx=module.padding_idx,
                        max_norm=module.max_norm,
                        norm_type=module.norm_type,
                        scale_grad_by_freq=module.scale_grad_by_freq,
                        sparse=module.sparse)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)


class FluxControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        hash_value = hash_state_dict_keys(state_dict)
        global_rename_dict = {
            "context_embedder": "context_embedder",
            "x_embedder": "x_embedder",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "final_norm_out.linear",
            "proj_out": "final_proj_out",
        }
        rename_dict = {
            "proj_out": "proj_out",
            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        rename_dict_single = {
            "attn.to_q": "a_to_q",
            "attn.to_k": "a_to_k",
            "attn.to_v": "a_to_v",
            "attn.norm_q": "norm_q_a",
            "attn.norm_k": "norm_k_a",
            "norm.linear": "norm.linear",
            "proj_mlp": "proj_in_besides_attn",
            "proj_out": "proj_out",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in global_rename_dict:
                    state_dict_[global_rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                elif prefix.startswith("single_transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "single_blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict_single:
                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                    else:
                        state_dict_[name] = param
                else:
                    state_dict_[name] = param
            else:
                state_dict_[name] = param
        # Fuse q/k/v/mlp projections of single blocks into one to_qkv_mlp tensor.
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_v.")],
                    state_dict_[name],
                ], dim=0)
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_v."))
                state_dict_.pop(name)
        # Fuse the remaining separate q/k/v projections into to_qkv tensors.
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name],  # the q entry is `name` itself
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name)
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        if hash_value == "78d18b9101345ff695f312e7e62538c0":
            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
            extra_kwargs = {"num_single_blocks": 0}
        elif hash_value == "52357cb26250681367488a8954c271e8":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
        elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
        else:
            extra_kwargs = {}
        return state_dict_, extra_kwargs

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
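For context, a minimal sketch of how the converter above is meant to be used when loading a Diffusers-format ControlNet checkpoint into this class. The file path is hypothetical, and per-checkpoint load_state_dict strictness may need adjusting.

# Sketch only: the checkpoint path is hypothetical.
from safetensors.torch import load_file

raw_state_dict = load_file("path/to/flux_controlnet.safetensors")
converter = FluxControlNetStateDictConverter()
state_dict, extra_kwargs = converter.from_diffusers(raw_state_dict)
model = FluxControlNet(**extra_kwargs)  # hash-matched kwargs select the block layout
model.load_state_dict(state_dict)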
diffsynth/models/flux_dit.py
ADDED
@@ -0,0 +1,398 @@
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange


def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
    batch_size, num_tokens = hidden_states.shape[0:2]
    ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
    hidden_states = hidden_states + scale * ip_hidden_states
    return hidden_states


class RoPEEmbedding(torch.nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        assert dim % 2 == 0, "The dimension must be even."

        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**scale)

        batch_size, seq_length = pos.shape
        out = torch.einsum("...n,d->...nd", pos, omega)
        cos_out = torch.cos(out)
        sin_out = torch.sin(out)

        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
        return out.float()

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)


class FluxJointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_b = RMSNorm(head_dim, eps=1e-6)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states_a.shape[0]

        # Part A
        qkv_a = self.a_to_qkv(hidden_states_a)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        # Part B
        qkv_b = self.b_to_qkv(hidden_states_b)
        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)

        q = torch.concat([q_b, q_a], dim=2)
        k = torch.concat([k_b, k_a], dim=2)
        v = torch.concat([v_b, v_a], dim=2)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
        if ipadapter_kwargs_list is not None:
            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b


class FluxJointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b


class FluxSingleAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states, image_rotary_emb):
        batch_size = hidden_states.shape[0]

        qkv_a = self.a_to_qkv(hidden_states)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        return hidden_states


class AdaLayerNormSingle(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa


class FluxSingleTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = dim // num_attention_heads
        self.dim = dim

        self.norm = AdaLayerNormSingle(dim)
        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)

        self.proj_out = torch.nn.Linear(dim * 5, dim)

    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states.shape[0]

        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        q, k = self.norm_q_a(q), self.norm_k_a(k)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        if ipadapter_kwargs_list is not None:
            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
        return hidden_states

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        residual = hidden_states_a
        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
        hidden_states_a = residual + hidden_states_a

        return hidden_states_a, hidden_states_b


class AdaLayerNormContinuous(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, conditioning):
        emb = self.linear(self.silu(conditioning))
        shift, scale = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
        return x


class FluxDiT(torch.nn.Module):

    _repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]

    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(input_dim, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

        self.final_norm_out = AdaLayerNormContinuous(3072)
        self.final_proj_out = torch.nn.Linear(3072, 64)

        self.input_dim = input_dim

    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        total_seq_len = N * prompt_seq_len + image_seq_len
        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        image_start = N * prompt_seq_len
        image_end = N * prompt_seq_len + image_seq_len
        # prompt-image mask
        for i in range(N):
            prompt_start = i * prompt_seq_len
            prompt_end = (i + 1) * prompt_seq_len
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt mask
        for i in range(N):
            for j in range(N):
                if i != j:
                    prompt_start_i = i * prompt_seq_len
                    prompt_end_i = (i + 1) * prompt_seq_len
                    prompt_start_j = j * prompt_seq_len
                    prompt_end_j = (j + 1) * prompt_seq_len
                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False

        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        return attention_mask

    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
        max_masks = 0
        attention_mask = None
        prompt_embs = [prompt_emb]
        if entity_masks is not None:
            # entity_masks
            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
            # global mask
            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
            entity_masks = entity_masks + [global_mask]  # append global to last
            # attention mask
            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attention_mask.unsqueeze(1)
            # embds: n_masks * b * seq * d
            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
            prompt_embs = local_embs + prompt_embs  # append global to last
        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
        prompt_emb = torch.cat(prompt_embs, dim=1)

        # positional embedding
        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        return prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
        # (Deprecated) The real forward is in `pipelines.flux_image`.
        return None
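A small shape-check sketch for the patchify reshape and the RoPE table above. The latent size is arbitrary, and the full FluxDiT is deliberately not instantiated (it is very large); the rearrange patterns and RoPEEmbedding arguments mirror the code above.

# Sketch only: arbitrary sizes, no model weights involved.
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)                            # (B, C, H, W)
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
restored = rearrange(tokens, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=32, W=32)
assert torch.equal(latents, restored)                           # patchify is a lossless reshape

rope = RoPEEmbedding(3072, 10000, [16, 56, 56])                 # axes_dim sums to head_dim 128
ids = torch.zeros(1, tokens.shape[1], 3)                        # (B, seq, 3) position ids
emb = rope(ids)                                                 # (1, 1, seq, 64, 2, 2) rotation table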
diffsynth/models/flux_infiniteyou.py
ADDED
@@ -0,0 +1,129 @@
import math
import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head), reshaped to a contiguous tensor
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class InfiniteYouImageProjector(nn.Module):

    def __init__(
        self,
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=8,
        embedding_dim=512,
        output_dim=4096,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)
        latents = latents.to(dtype=x.dtype, device=x.device)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

    @staticmethod
    def state_dict_converter():
        return FluxInfiniteYouImageProjectorStateDictConverter()


class FluxInfiniteYouImageProjectorStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict['image_proj']
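A minimal sketch of the projector's input/output contract; the random tensor stands in for a real 512-dim face-identity embedding, which is what InfiniteYou feeds through this resampler.

# Sketch only: a random id embedding stands in for a real 512-dim identity feature.
import torch

projector = InfiniteYouImageProjector()
id_embeds = torch.randn(1, 1, 512)             # (B, n_tokens, embedding_dim)
prompt_tokens = projector(id_embeds)           # (B, num_queries=8, output_dim=4096)
# The 8 output tokens live in the 4096-dim text-embedding space and can be
# concatenated with the prompt embedding for conditioning.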
diffsynth/models/flux_ipadapter.py
ADDED
@@ -0,0 +1,110 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch


class SiglipVisionModelSO400M(SiglipVisionModel):
    def __init__(self):
        config = SiglipVisionConfig(
            hidden_size=1152,
            image_size=384,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            patch_size=14,
            architectures=["SiglipModel"],
            initializer_factor=1.0,
            torch_dtype="float32",
            transformers_version="4.37.0.dev0"
        )
        super().__init__(config)

class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

class IpAdapterModule(torch.nn.Module):
    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        output_dim = num_attention_heads * attention_head_dim
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)

    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        # ip_k
        ip_k = self.to_k_ip(hidden_states)
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_k = self.norm_added_k(ip_k)
        # ip_v
        ip_v = self.to_v_ip(hidden_states)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return ip_k, ip_v


class FluxIpAdapter(torch.nn.Module):
    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
        super().__init__()
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
        self.set_adapter()

    def set_adapter(self):
        self.call_block_id = {i: i for i in range(len(self.ipadapter_modules))}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for block_id in self.call_block_id:
            ipadapter_id = self.call_block_id[block_id]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            ip_kv_dict[block_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return FluxIpAdapterStateDictConverter()


class FluxIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            name_ = 'ipadapter_modules.' + name
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
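A minimal sketch of the adapter's output contract; the random tensor stands in for a pooled SigLIP image embedding, and instantiating the module allocates sizable (untrained) weights.

# Sketch only: random features stand in for a pooled SigLIP image embedding.
import torch

ip_adapter = FluxIpAdapter()
image_embeds = torch.randn(1, 1152)            # pooled SigLIP SO400M embedding
ip_kv_dict = ip_adapter(image_embeds, scale=0.7)
# One {"ip_k", "ip_v", "scale"} entry per transformer block (19 joint + 38
# single = 57 by default), consumed by interact_with_ipadapter inside the
# DiT attention layers.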
diffsynth/models/flux_lora_encoder.py
ADDED
@@ -0,0 +1,521 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    # Plain attention fallback for environments without a fused kernel.
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value


class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        batch_size = q.shape[0]
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        hidden_states = hidden_states + scale * ip_hidden_states
        return hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops  # lazy import keeps xformers an optional dependency
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)


class CLIPEncoderLayer(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask=None):
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.fc1(hidden_states)
        if self.use_quick_gelu:
            hidden_states = self.quickGELU(hidden_states)
        else:
            hidden_states = torch.nn.functional.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class SDTextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (this is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        # Causal mask: -inf above the diagonal, 0 elsewhere.
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        embeds = self.final_layer_norm(embeds)
        return embeds

    @staticmethod
    def state_dict_converter():
        return SDTextEncoderStateDictConverter()


class SDTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        # The original file spells out the rename for every tensor of the 12
        # encoder layers; the identical key-value pairs are generated here
        # programmatically instead.
        prefix = "cond_stage_model.transformer.text_model."
        rename_dict = {
            prefix + "embeddings.token_embedding.weight": "token_embedding.weight",
            prefix + "embeddings.position_embedding.weight": "position_embeds",
            prefix + "final_layer_norm.weight": "final_layer_norm.weight",
            prefix + "final_layer_norm.bias": "final_layer_norm.bias",
        }
        layer_rename_dict = {
            "layer_norm1": "layer_norm1", "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1", "mlp.fc2": "fc2",
            "self_attn.q_proj": "attn.to_q", "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v", "self_attn.out_proj": "attn.to_out",
        }
        for layer_id in range(12):
            for source, target in layer_rename_dict.items():
                for tail in ("weight", "bias"):
                    rename_dict[f"{prefix}encoder.layers.{layer_id}.{source}.{tail}"] = f"encoders.{layer_id}.{target}.{tail}"
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == prefix + "embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
        return state_dict_


class LoRALayerBlock(torch.nn.Module):
    def __init__(self, L, dim_in, dim_out):
        super().__init__()
        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def forward(self, lora_A, lora_B):
        # Probe the LoRA update (B @ A) with a learned query x.
        x = self.x @ lora_A.T @ lora_B.T
        x = self.layer_norm(x)
        return x


class LoRAEmbedder(torch.nn.Module):
    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()

        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"]
            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
        self.model_dict = torch.nn.ModuleDict(model_dict)

        proj_dict = {}
        for lora_pattern in lora_patterns:
            layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
            # Check the mangled key that is actually stored, so each layer type
            # gets exactly one shared projection (the original checked the
            # unmangled name and silently recreated the Linear on every hit).
            if layer_type.replace(".", "___") not in proj_dict:
                proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
        self.proj_dict = torch.nn.ModuleDict(proj_dict)

        self.lora_patterns = lora_patterns

    def default_lora_patterns(self):
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
            "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
        for i in range(38):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        return lora_patterns

    def forward(self, lora):
        lora_emb = []
        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            lora_A = lora[name + ".lora_A.weight"]
            lora_B = lora[name + ".lora_B.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        lora_emb = torch.concat(lora_emb, dim=1)
        return lora_emb


class FluxLoRAEncoder(torch.nn.Module):
    def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
        super().__init__()
        self.num_embeds_per_lora = num_embeds_per_lora
        # embedder
        self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])

        # special embedding
        self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
        self.num_special_embeds = num_special_embeds

        # final layer
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
        self.final_linear = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, lora):
        lora_embeds = self.embedder(lora)
        special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
        embeds = torch.concat([special_embeds, lora_embeds], dim=1)
        for encoder in self.encoders:
            embeds = encoder(embeds)
        embeds = embeds[:, :self.num_special_embeds]
        embeds = self.final_layer_norm(embeds)
        embeds = self.final_linear(embeds)
        return embeds

    @staticmethod
    def state_dict_converter():
        return FluxLoRAEncoderStateDictConverter()


class FluxLoRAEncoderStateDictConverter:
    def from_civitai(self, state_dict):
        return state_dict
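Note: a minimal usage sketch for the text encoder above (not part of the upstream diff; assumes random weights and the default 12-layer configuration). It illustrates the clip_skip early exit: the loop breaks once encoder_id + clip_skip reaches the layer count, so clip_skip=2 runs layers 0-10 and skips the final layer, while clip_skip=1 runs all 12.

import torch

text_encoder = SDTextEncoder()                 # defaults: 12 layers, 768-dim, 77 positions
input_ids = torch.randint(0, 49408, (1, 77))   # one tokenized 77-token prompt
embeds = text_encoder(input_ids, clip_skip=2)  # stops after encoder 10; final_layer_norm still applies
print(embeds.shape)                            # torch.Size([1, 77, 768])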
diffsynth/models/flux_lora_patcher.py
ADDED
@@ -0,0 +1,306 @@
| 1 |
+
import torch, math
|
| 2 |
+
from ..core.loader import load_state_dict
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
class GeneralLoRALoader:
|
| 6 |
+
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
| 7 |
+
self.device = device
|
| 8 |
+
self.torch_dtype = torch_dtype
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_name_dict(self, lora_state_dict):
|
| 12 |
+
lora_name_dict = {}
|
| 13 |
+
for key in lora_state_dict:
|
| 14 |
+
if ".lora_B." not in key:
|
| 15 |
+
continue
|
| 16 |
+
keys = key.split(".")
|
| 17 |
+
if len(keys) > keys.index("lora_B") + 2:
|
| 18 |
+
keys.pop(keys.index("lora_B") + 1)
|
| 19 |
+
keys.pop(keys.index("lora_B"))
|
| 20 |
+
if keys[0] == "diffusion_model":
|
| 21 |
+
keys.pop(0)
|
| 22 |
+
keys.pop(-1)
|
| 23 |
+
target_name = ".".join(keys)
|
| 24 |
+
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
| 25 |
+
return lora_name_dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
| 29 |
+
updated_num = 0
|
| 30 |
+
lora_name_dict = self.get_name_dict(state_dict_lora)
|
| 31 |
+
for name, module in model.named_modules():
|
| 32 |
+
if name in lora_name_dict:
|
| 33 |
+
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
| 34 |
+
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
| 35 |
+
if len(weight_up.shape) == 4:
|
| 36 |
+
weight_up = weight_up.squeeze(3).squeeze(2)
|
| 37 |
+
weight_down = weight_down.squeeze(3).squeeze(2)
|
| 38 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
| 39 |
+
else:
|
| 40 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
| 41 |
+
state_dict = module.state_dict()
|
| 42 |
+
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
| 43 |
+
module.load_state_dict(state_dict)
|
| 44 |
+
updated_num += 1
|
| 45 |
+
print(f"{updated_num} tensors are updated by LoRA.")
|
| 46 |
+
|
| 47 |
+
class FluxLoRALoader(GeneralLoRALoader):
|
| 48 |
+
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
| 49 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
| 50 |
+
|
| 51 |
+
self.diffusers_rename_dict = {
|
| 52 |
+
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
| 53 |
+
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
| 54 |
+
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
| 55 |
+
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
| 56 |
+
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
| 57 |
+
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
| 58 |
+
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
| 59 |
+
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
| 60 |
+
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
| 61 |
+
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
| 62 |
+
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
| 63 |
+
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
| 64 |
+
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
| 65 |
+
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
| 66 |
+
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
| 67 |
+
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
| 68 |
+
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
| 69 |
+
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
| 70 |
+
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
| 71 |
+
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
| 72 |
+
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
| 73 |
+
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
| 74 |
+
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
| 75 |
+
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
| 76 |
+
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
| 77 |
+
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
| 78 |
+
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
| 79 |
+
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
| 80 |
+
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
| 81 |
+
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
| 82 |
+
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
| 83 |
+
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
| 84 |
+
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
| 85 |
+
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
| 86 |
+
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
| 87 |
+
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
| 88 |
+
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
| 89 |
+
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
| 90 |
+
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
| 91 |
+
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
self.civitai_rename_dict = {
|
| 95 |
+
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
| 96 |
+
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
| 97 |
+
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
| 98 |
+
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
| 99 |
+
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
| 100 |
+
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
| 101 |
+
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
| 102 |
+
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
| 103 |
+
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
| 104 |
+
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
| 105 |
+
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
| 106 |
+
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
| 107 |
+
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
| 108 |
+
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
| 109 |
+
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
| 110 |
+
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
| 111 |
+
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
| 112 |
+
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
| 113 |
+
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
| 114 |
+
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
| 115 |
+
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
| 116 |
+
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
| 117 |
+
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
| 118 |
+
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
| 119 |
+
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
| 120 |
+
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
| 124 |
+
super().load(model, state_dict_lora, alpha)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def convert_state_dict(self,state_dict):
|
| 128 |
+
|
| 129 |
+
def guess_block_id(name,model_resource):
|
| 130 |
+
if model_resource == 'civitai':
|
| 131 |
+
names = name.split("_")
|
| 132 |
+
for i in names:
|
| 133 |
+
if i.isdigit():
|
| 134 |
+
return i, name.replace(f"_{i}_", "_blockid_")
|
| 135 |
+
if model_resource == 'diffusers':
|
| 136 |
+
names = name.split(".")
|
| 137 |
+
for i in names:
|
| 138 |
+
if i.isdigit():
|
| 139 |
+
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
| 140 |
+
return None, None
|
| 141 |
+
|
| 142 |
+
def guess_resource(state_dict):
|
| 143 |
+
for k in state_dict:
|
| 144 |
+
if "lora_unet_" in k:
|
| 145 |
+
return 'civitai'
|
| 146 |
+
elif k.startswith("transformer."):
|
| 147 |
+
return 'diffusers'
|
| 148 |
+
else:
|
| 149 |
+
None
|
| 150 |
+
|
| 151 |
+
model_resource = guess_resource(state_dict)
|
| 152 |
+
if model_resource is None:
|
| 153 |
+
return state_dict
|
| 154 |
+
|
| 155 |
+
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
| 156 |
+
def guess_alpha(state_dict):
|
| 157 |
+
for name, param in state_dict.items():
|
| 158 |
+
if ".alpha" in name:
|
| 159 |
+
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
| 160 |
+
name_ = name.replace(".alpha", suffix)
|
| 161 |
+
if name_ in state_dict:
|
| 162 |
+
lora_alpha = param.item() / state_dict[name_].shape[0]
|
| 163 |
+
lora_alpha = math.sqrt(lora_alpha)
|
| 164 |
+
return lora_alpha
|
| 165 |
+
|
| 166 |
+
return 1
|
| 167 |
+
|
| 168 |
+
alpha = guess_alpha(state_dict)
|
| 169 |
+
|
| 170 |
+
state_dict_ = {}
|
| 171 |
+
for name, param in state_dict.items():
|
| 172 |
+
block_id, source_name = guess_block_id(name,model_resource)
|
| 173 |
+
if alpha != 1:
|
| 174 |
+
param *= alpha
|
| 175 |
+
if source_name in rename_dict:
|
| 176 |
+
target_name = rename_dict[source_name]
|
| 177 |
+
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
| 178 |
+
state_dict_[target_name] = param
|
| 179 |
+
else:
|
| 180 |
+
state_dict_[name] = param
|
| 181 |
+
|
| 182 |
+
if model_resource == 'diffusers':
|
| 183 |
+
for name in list(state_dict_.keys()):
|
| 184 |
+
if "single_blocks." in name and ".a_to_q." in name:
|
| 185 |
+
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
| 186 |
+
if mlp is None:
|
| 187 |
+
dim = 4
|
| 188 |
+
if 'lora_A' in name:
|
| 189 |
+
dim = 1
|
| 190 |
+
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
| 191 |
+
*state_dict_[name].shape[1:],
|
| 192 |
+
dtype=state_dict_[name].dtype)
|
| 193 |
+
else:
|
| 194 |
+
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
| 195 |
+
if 'lora_A' in name:
|
| 196 |
+
param = torch.concat([
|
| 197 |
+
state_dict_.pop(name),
|
| 198 |
+
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
| 199 |
+
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
| 200 |
+
mlp,
|
| 201 |
+
], dim=0)
|
| 202 |
+
elif 'lora_B' in name:
|
| 203 |
+
d, r = state_dict_[name].shape
|
| 204 |
+
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
| 205 |
+
param[:d, :r] = state_dict_.pop(name)
|
| 206 |
+
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
| 207 |
+
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
| 208 |
+
param[3*d:, 3*r:] = mlp
|
| 209 |
+
else:
|
| 210 |
+
param = torch.concat([
|
| 211 |
+
state_dict_.pop(name),
|
| 212 |
+
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
+                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+                    mlp,
+                ], dim=0)
+                name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
+                state_dict_[name_] = param
+        for name in list(state_dict_.keys()):
+            for component in ["a", "b"]:
+                if f".{component}_to_q." in name:
+                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+                    concat_dim = 0
+                    if 'lora_A' in name:
+                        param = torch.concat([
+                            state_dict_[name],
+                            state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+                            state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+                        ], dim=0)
+                    elif 'lora_B' in name:
+                        origin = state_dict_[name]
+                        d, r = origin.shape
+                        param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
+                        param[:d, :r] = state_dict_[name]
+                        param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
+                        param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
+                    else:
+                        param = torch.concat([
+                            state_dict_[name],
+                            state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+                            state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+                        ], dim=0)
+                    state_dict_[name_] = param
+                    state_dict_.pop(name)
+                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
+        return state_dict_
+
+
+class LoraMerger(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
+        self.bias = torch.nn.Parameter(torch.randn((dim,)))
+        self.activation = torch.nn.Sigmoid()
+        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
+        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
+
+    def forward(self, base_output, lora_outputs):
+        norm_base_output = self.norm_base(base_output)
+        norm_lora_outputs = self.norm_lora(lora_outputs)
+        gate = self.activation(
+            norm_base_output * self.weight_base
+            + norm_lora_outputs * self.weight_lora
+            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
+        )
+        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
+        return output
+
+class FluxLoraPatcher(torch.nn.Module):
+    def __init__(self, lora_patterns=None):
+        super().__init__()
+        if lora_patterns is None:
+            lora_patterns = self.default_lora_patterns()
+        model_dict = {}
+        for lora_pattern in lora_patterns:
+            name, dim = lora_pattern["name"], lora_pattern["dim"]
+            model_dict[name.replace(".", "___")] = LoraMerger(dim)
+        self.model_dict = torch.nn.ModuleDict(model_dict)
+
+    def default_lora_patterns(self):
+        lora_patterns = []
+        lora_dict = {
+            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
+            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
+        }
+        for i in range(19):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
+        for i in range(38):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"single_blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        return lora_patterns
+
+    def forward(self, base_output, lora_outputs, name):
+        return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
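
For orientation, a minimal sketch of how LoraMerger fuses several LoRA branch outputs into a base layer output. Not part of the diff; the batch, sequence, and LoRA counts below are invented for illustration:

import torch

merger = LoraMerger(dim=3072)
base_output = torch.randn(1, 512, 3072)    # output of the base linear layer
lora_outputs = torch.randn(2, 512, 3072)   # stacked outputs of two active LoRAs
fused = merger(base_output, lora_outputs)  # sigmoid gate, then sum over the LoRA axis (dim=0)
print(fused.shape)                         # torch.Size([1, 512, 3072])
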
diffsynth/models/flux_text_encoder_clip.py ADDED
@@ -0,0 +1,112 @@
+import torch
+
+
+class Attention(torch.nn.Module):
+
+    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+        super().__init__()
+        dim_inner = head_dim * num_heads
+        kv_dim = kv_dim if kv_dim is not None else q_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
+        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        q = self.to_q(hidden_states)
+        k = self.to_k(encoder_hidden_states)
+        v = self.to_v(encoder_hidden_states)
+
+        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+
+        hidden_states = self.to_out(hidden_states)
+
+        return hidden_states
+
+
+class CLIPEncoderLayer(torch.nn.Module):
+    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
+        super().__init__()
+        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
+        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
+        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
+        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
+        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
+
+        self.use_quick_gelu = use_quick_gelu
+
+    def quickGELU(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward(self, hidden_states, attn_mask=None):
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.fc1(hidden_states)
+        if self.use_quick_gelu:
+            hidden_states = self.quickGELU(hidden_states)
+        else:
+            hidden_states = torch.nn.functional.gelu(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class FluxTextEncoderClip(torch.nn.Module):
+    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
+        super().__init__()
+
+        # token_embedding
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
+
+        # position_embeds (This is a fixed tensor)
+        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
+
+        # encoders
+        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
+
+        # attn_mask
+        self.attn_mask = self.attention_mask(max_position_embeddings)
+
+        # final_layer_norm
+        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
+
+    def attention_mask(self, length):
+        mask = torch.empty(length, length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def forward(self, input_ids, clip_skip=2, extra_mask=None):
+        embeds = self.token_embedding(input_ids)
+        embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        if extra_mask is not None:
+            attn_mask[:, extra_mask[0]==0] = float("-inf")
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                hidden_states = embeds
+        embeds = self.final_layer_norm(embeds)
+        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+        return pooled_embeds, hidden_states
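
A minimal usage sketch for the CLIP text encoder above. Not part of the diff: the tokenizer checkpoint is an assumption (any CLIP-L tokenizer with 77-token padding should work), and the randomly initialized weights stand in for a real checkpoint:

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
text_encoder = FluxTextEncoderClip().eval()  # untrained weights, for shape checking only

input_ids = tokenizer("a photo of a cat", padding="max_length", max_length=77, return_tensors="pt").input_ids
with torch.no_grad():
    pooled_embeds, hidden_states = text_encoder(input_ids)
print(pooled_embeds.shape, hidden_states.shape)  # torch.Size([1, 768]) torch.Size([1, 77, 768])
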
diffsynth/models/flux_text_encoder_t5.py ADDED
@@ -0,0 +1,43 @@
+import torch
+from transformers import T5EncoderModel, T5Config
+
+
+class FluxTextEncoderT5(T5EncoderModel):
+    def __init__(self):
+        config = T5Config(**{
+            "architectures": [
+                "T5EncoderModel"
+            ],
+            "classifier_dropout": 0.0,
+            "d_ff": 10240,
+            "d_kv": 64,
+            "d_model": 4096,
+            "decoder_start_token_id": 0,
+            "dense_act_fn": "gelu_new",
+            "dropout_rate": 0.1,
+            "dtype": "bfloat16",
+            "eos_token_id": 1,
+            "feed_forward_proj": "gated-gelu",
+            "initializer_factor": 1.0,
+            "is_encoder_decoder": True,
+            "is_gated_act": True,
+            "layer_norm_epsilon": 1e-06,
+            "model_type": "t5",
+            "num_decoder_layers": 24,
+            "num_heads": 64,
+            "num_layers": 24,
+            "output_past": True,
+            "pad_token_id": 0,
+            "relative_attention_max_distance": 128,
+            "relative_attention_num_buckets": 32,
+            "tie_word_embeddings": False,
+            "transformers_version": "4.57.1",
+            "use_cache": True,
+            "vocab_size": 32128
+        })
+        super().__init__(config)
+
+    def forward(self, input_ids):
+        outputs = super().forward(input_ids=input_ids)
+        prompt_emb = outputs.last_hidden_state
+        return prompt_emb
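
The class above hard-codes a T5-XXL-scale encoder config (d_model 4096, 24 layers), so instantiating it allocates the full encoder. Not part of the diff; the sketch below only checks shapes with dummy token ids:

import torch

text_encoder_t5 = FluxTextEncoderT5().eval()   # allocates a full T5-XXL-scale encoder
input_ids = torch.randint(0, 32128, (1, 32))   # dummy ids in the T5 vocabulary
with torch.no_grad():
    prompt_emb = text_encoder_t5(input_ids)
print(prompt_emb.shape)                        # torch.Size([1, 32, 4096])
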
diffsynth/models/flux_vae.py ADDED
@@ -0,0 +1,451 @@
+import torch
+from einops import rearrange, repeat
+
+
+class TileWorker:
+    def __init__(self):
+        pass
+
+
+    def mask(self, height, width, border_width):
+        # Create a mask with shape (height, width).
+        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
+        x = torch.arange(height).repeat(width, 1).T
+        y = torch.arange(width).repeat(height, 1)
+        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
+        mask = (mask / border_width).clip(0, 1)
+        return mask
+
+
+    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
+        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
+        batch_size, channel, _, _ = model_input.shape
+        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
+        unfold_operator = torch.nn.Unfold(
+            kernel_size=(tile_size, tile_size),
+            stride=(tile_stride, tile_stride)
+        )
+        model_input = unfold_operator(model_input)
+        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
+
+        return model_input
+
+
+    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
+        # Call y=forward_fn(x) for each tile
+        tile_num = model_input.shape[-1]
+        model_output_stack = []
+
+        for tile_id in range(0, tile_num, tile_batch_size):
+
+            # process input
+            tile_id_ = min(tile_id + tile_batch_size, tile_num)
+            x = model_input[:, :, :, :, tile_id: tile_id_]
+            x = x.to(device=inference_device, dtype=inference_dtype)
+            x = rearrange(x, "b c h w n -> (n b) c h w")
+
+            # process output
+            y = forward_fn(x)
+            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
+            y = y.to(device=tile_device, dtype=tile_dtype)
+            model_output_stack.append(y)
+
+        model_output = torch.concat(model_output_stack, dim=-1)
+        return model_output
+
+
+    def io_scale(self, model_output, tile_size):
+        # Determine the size modification happened in forward_fn
+        # We only consider the same scale on height and width.
+        io_scale = model_output.shape[2] / tile_size
+        return io_scale
+
+
+    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
+        # The reversed function of tile
+        mask = self.mask(tile_size, tile_size, border_width)
+        mask = mask.to(device=tile_device, dtype=tile_dtype)
+        mask = rearrange(mask, "h w -> 1 1 h w 1")
+        model_output = model_output * mask
+
+        fold_operator = torch.nn.Fold(
+            output_size=(height, width),
+            kernel_size=(tile_size, tile_size),
+            stride=(tile_stride, tile_stride)
+        )
+        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
+        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
+        model_output = fold_operator(model_output) / fold_operator(mask)
+
+        return model_output
+
+
+    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
+        # Prepare
+        inference_device, inference_dtype = model_input.device, model_input.dtype
+        height, width = model_input.shape[2], model_input.shape[3]
+        border_width = int(tile_stride*0.5) if border_width is None else border_width
+
+        # tile
+        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
+
+        # inference
+        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
+
+        # resize
+        io_scale = self.io_scale(model_output, tile_size)
+        height, width = int(height*io_scale), int(width*io_scale)
+        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
+        border_width = int(border_width*io_scale)
+
+        # untile
+        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
+
+        # Done!
+        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
+        return model_output
+
+
+class ConvAttention(torch.nn.Module):
+
+    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+        super().__init__()
+        dim_inner = head_dim * num_heads
+        kv_dim = kv_dim if kv_dim is not None else q_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
+        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        q = self.to_q(conv_input)
+        q = rearrange(q[:, :, :, 0], "B C L -> B L C")
+        conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
+        k = self.to_k(conv_input)
+        v = self.to_v(conv_input)
+        k = rearrange(k[:, :, :, 0], "B C L -> B L C")
+        v = rearrange(v[:, :, :, 0], "B C L -> B L C")
+
+        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        hidden_states = self.to_out(conv_input)
+        hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
+
+        return hidden_states
+
+
+class Attention(torch.nn.Module):
+
+    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+        super().__init__()
+        dim_inner = head_dim * num_heads
+        kv_dim = kv_dim if kv_dim is not None else q_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
+        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        q = self.to_q(hidden_states)
+        k = self.to_k(encoder_hidden_states)
+        v = self.to_v(encoder_hidden_states)
+
+        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+
+        hidden_states = self.to_out(hidden_states)
+
+        return hidden_states
+
+
+class VAEAttentionBlock(torch.nn.Module):
+
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
+
+        if use_conv_attention:
+            self.transformer_blocks = torch.nn.ModuleList([
+                ConvAttention(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    bias_q=True,
+                    bias_kv=True,
+                    bias_out=True
+                )
+                for d in range(num_layers)
+            ])
+        else:
+            self.transformer_blocks = torch.nn.ModuleList([
+                Attention(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    bias_q=True,
+                    bias_kv=True,
+                    bias_out=True
+                )
+                for d in range(num_layers)
+            ])
+
+    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+        batch, _, height, width = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+
+        for block in self.transformer_blocks:
+            hidden_states = block(hidden_states)
+
+        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+        hidden_states = hidden_states + residual
+
+        return hidden_states, time_emb, text_emb, res_stack
+
+
+class ResnetBlock(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
+        super().__init__()
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if temb_channels is not None:
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.nonlinearity = torch.nn.SiLU()
+        self.conv_shortcut = None
+        if in_channels != out_channels:
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+        x = hidden_states
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+        x = self.conv1(x)
+        if time_emb is not None:
+            emb = self.nonlinearity(time_emb)
+            emb = self.time_emb_proj(emb)[:, :, None, None]
+            x = x + emb
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+        x = self.conv2(x)
+        if self.conv_shortcut is not None:
+            hidden_states = self.conv_shortcut(hidden_states)
+        hidden_states = hidden_states + x
+        return hidden_states, time_emb, text_emb, res_stack
+
+
+class UpSampler(torch.nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
+
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        hidden_states = self.conv(hidden_states)
+        return hidden_states, time_emb, text_emb, res_stack
+
+
+class DownSampler(torch.nn.Module):
+    def __init__(self, channels, padding=1, extra_padding=False):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
+        self.extra_padding = extra_padding
+
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+        if self.extra_padding:
+            hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
+        hidden_states = self.conv(hidden_states)
+        return hidden_states, time_emb, text_emb, res_stack
+
+
+class FluxVAEDecoder(torch.nn.Module):
+    def __init__(self, use_conv_attention=True):
+        super().__init__()
+        self.scaling_factor = 0.3611
+        self.shift_factor = 0.1159
+        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1)  # Different from SD 1.x
+
+        self.blocks = torch.nn.ModuleList([
+            # UNetMidBlock2D
+            ResnetBlock(512, 512, eps=1e-6),
+            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
+            ResnetBlock(512, 512, eps=1e-6),
+            # UpDecoderBlock2D
+            ResnetBlock(512, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            UpSampler(512),
+            # UpDecoderBlock2D
+            ResnetBlock(512, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            UpSampler(512),
+            # UpDecoderBlock2D
+            ResnetBlock(512, 256, eps=1e-6),
+            ResnetBlock(256, 256, eps=1e-6),
+            ResnetBlock(256, 256, eps=1e-6),
+            UpSampler(256),
+            # UpDecoderBlock2D
+            ResnetBlock(256, 128, eps=1e-6),
+            ResnetBlock(128, 128, eps=1e-6),
+            ResnetBlock(128, 128, eps=1e-6),
+        ])
+
+        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
+        self.conv_act = torch.nn.SiLU()
+        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
+
+    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
+        hidden_states = TileWorker().tiled_forward(
+            lambda x: self.forward(x),
+            sample,
+            tile_size,
+            tile_stride,
+            tile_device=sample.device,
+            tile_dtype=sample.dtype
+        )
+        return hidden_states
+
+    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        # For the VAE Decoder, we do not need to apply the tiler on each layer.
+        if tiled:
+            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+        # 1. pre-process
+        hidden_states = sample / self.scaling_factor + self.shift_factor
+        hidden_states = self.conv_in(hidden_states)
+        time_emb = None
+        text_emb = None
+        res_stack = None
+
+        # 2. blocks
+        for i, block in enumerate(self.blocks):
+            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+        # 3. output
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
+class FluxVAEEncoder(torch.nn.Module):
+    def __init__(self, use_conv_attention=True):
+        super().__init__()
+        self.scaling_factor = 0.3611
+        self.shift_factor = 0.1159
+        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
+
+        self.blocks = torch.nn.ModuleList([
+            # DownEncoderBlock2D
+            ResnetBlock(128, 128, eps=1e-6),
+            ResnetBlock(128, 128, eps=1e-6),
+            DownSampler(128, padding=0, extra_padding=True),
+            # DownEncoderBlock2D
+            ResnetBlock(128, 256, eps=1e-6),
+            ResnetBlock(256, 256, eps=1e-6),
+            DownSampler(256, padding=0, extra_padding=True),
+            # DownEncoderBlock2D
+            ResnetBlock(256, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            DownSampler(512, padding=0, extra_padding=True),
+            # DownEncoderBlock2D
+            ResnetBlock(512, 512, eps=1e-6),
+            ResnetBlock(512, 512, eps=1e-6),
+            # UNetMidBlock2D
+            ResnetBlock(512, 512, eps=1e-6),
+            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
+            ResnetBlock(512, 512, eps=1e-6),
+        ])
+
+        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
+        self.conv_act = torch.nn.SiLU()
+        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
+
+    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
+        hidden_states = TileWorker().tiled_forward(
+            lambda x: self.forward(x),
+            sample,
+            tile_size,
+            tile_stride,
+            tile_device=sample.device,
+            tile_dtype=sample.dtype
+        )
+        return hidden_states
+
+    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        # For the VAE Encoder, we do not need to apply the tiler on each layer.
+        if tiled:
+            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+        # 1. pre-process
+        hidden_states = self.conv_in(sample)
+        time_emb = None
+        text_emb = None
+        res_stack = None
+
+        # 2. blocks
+        for i, block in enumerate(self.blocks):
+            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+        # 3. output
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        hidden_states = hidden_states[:, :16]
+        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
+
+        return hidden_states
+
+    def encode_video(self, sample, batch_size=8):
+        B = sample.shape[0]
+        hidden_states = []
+
+        for i in range(0, sample.shape[2], batch_size):
+
+            j = min(i + batch_size, sample.shape[2])
+            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
+
+            hidden_states_batch = self(sample_batch)
+            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
+
+            hidden_states.append(hidden_states_batch)
+
+        hidden_states = torch.concat(hidden_states, dim=2)
+        return hidden_states
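
To see the tiled path end to end, a sketch (not part of the diff) that decodes random latents with tiles smaller than the latent grid, forcing TileWorker to blend overlapping tiles; the sizes are arbitrary:

import torch

decoder = FluxVAEDecoder().eval()       # untrained weights, for shape checking only
latents = torch.randn(1, 16, 64, 64)    # (B, 16, H/8, W/8) latent grid
with torch.no_grad():
    image = decoder(latents, tiled=True, tile_size=32, tile_stride=16)
print(image.shape)                       # torch.Size([1, 3, 512, 512])
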
diffsynth/models/flux_value_control.py ADDED
@@ -0,0 +1,56 @@
+import torch
+from .general_modules import TemporalTimesteps
+
+
+class MultiValueEncoder(torch.nn.Module):
+    def __init__(self, encoders=()):
+        super().__init__()
+        if not isinstance(encoders, list):
+            encoders = [encoders]
+        self.encoders = torch.nn.ModuleList(encoders)
+
+    def __call__(self, values, dtype):
+        emb = []
+        for encoder, value in zip(self.encoders, values):
+            if value is not None:
+                value = value.unsqueeze(0)
+                emb.append(encoder(value, dtype))
+        emb = torch.concat(emb, dim=0)
+        return emb
+
+
+class SingleValueEncoder(torch.nn.Module):
+    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
+        super().__init__()
+        self.prefer_len = prefer_len
+        self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
+        self.prefer_value_embedder = torch.nn.Sequential(
+            torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
+        )
+        self.positional_embedding = torch.nn.Parameter(
+            torch.randn(self.prefer_len, dim_out)
+        )
+
+    def forward(self, value, dtype):
+        value = value * 1000
+        emb = self.prefer_proj(value).to(dtype)
+        emb = self.prefer_value_embedder(emb).squeeze(0)
+        base_embeddings = emb.expand(self.prefer_len, -1)
+        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
+        learned_embeddings = base_embeddings + positional_embedding
+        return learned_embeddings
+
+    @staticmethod
+    def state_dict_converter():
+        return SingleValueEncoderStateDictConverter()
+
+
+class SingleValueEncoderStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
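
A sketch of the value-control encoder in isolation (not part of the diff): a scalar is scaled by 1000, passed through the sinusoidal projection and MLP, then broadcast over prefer_len learned positions. The scalar here is arbitrary:

import torch

encoder = SingleValueEncoder(dim_in=256, dim_out=4096, prefer_len=32)
value = torch.tensor([0.7])                 # one control scalar
emb = encoder(value, dtype=torch.float32)
print(emb.shape)                            # torch.Size([32, 4096])
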
diffsynth/models/general_modules.py ADDED
@@ -0,0 +1,146 @@
+import torch, math
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+    computation_device = None,
+    align_dtype_to_timestep = False,
+):
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    if align_dtype_to_timestep:
+        emb = emb.to(timesteps.dtype)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+class TemporalTimesteps(torch.nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.computation_device = computation_device
+        self.scale = scale
+        self.align_dtype_to_timestep = align_dtype_to_timestep
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            computation_device=self.computation_device,
+            scale=self.scale,
+            align_dtype_to_timestep=self.align_dtype_to_timestep,
+        )
+        return t_emb
+
+
+class DiffusersCompatibleTimestepProj(torch.nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(dim_in, dim_out)
+        self.act = torch.nn.SiLU()
+        self.linear_2 = torch.nn.Linear(dim_out, dim_out)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = self.act(x)
+        x = self.linear_2(x)
+        return x
+
+
+class TimestepEmbeddings(torch.nn.Module):
+    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
+        super().__init__()
+        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
+        if diffusers_compatible_format:
+            self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
+        else:
+            self.timestep_embedder = torch.nn.Sequential(
+                torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
+            )
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
+            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
+
+    def forward(self, timestep, dtype, addition_t_cond=None):
+        time_emb = self.time_proj(timestep).to(dtype)
+        time_emb = self.timestep_embedder(time_emb)
+        if addition_t_cond is not None:
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=dtype)
+            time_emb = time_emb + addition_t_emb
+        return time_emb
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim, eps, elementwise_affine=True):
+        super().__init__()
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = torch.nn.Parameter(torch.ones((dim,)))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        hidden_states = hidden_states.to(input_dtype)
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        return hidden_states
+
+
+class AdaLayerNorm(torch.nn.Module):
+    def __init__(self, dim, single=False, dual=False):
+        super().__init__()
+        self.single = single
+        self.dual = dual
+        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
+        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb):
+        emb = self.linear(torch.nn.functional.silu(emb))
+        if self.single:
+            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
+            x = self.norm(x) * (1 + scale) + shift
+            return x
+        elif self.dual:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
+            norm_x = self.norm(x)
+            x = norm_x * (1 + scale_msa) + shift_msa
+            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
+            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
+        else:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
+            x = self.norm(x) * (1 + scale_msa) + shift_msa
+            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
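
A quick shape check for the timestep pathway defined above (not part of the diff); the embedding widths follow common DiT-style sizes and are not tied to a specific model here:

import torch

time_embed = TimestepEmbeddings(dim_in=256, dim_out=3072)
timestep = torch.tensor([981.0])            # one diffusion timestep
emb = time_embed(timestep, dtype=torch.float32)
print(emb.shape)                            # torch.Size([1, 3072])
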
diffsynth/models/longcat_video_dit.py
ADDED
|
@@ -0,0 +1,902 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.amp as amp
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
from .wan_video_dit import flash_attention
|
| 12 |
+
from ..core.device.npu_compatible_device import get_device_type
|
| 13 |
+
from ..core.gradient import gradient_checkpoint_forward
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RMSNorm_FP32(torch.nn.Module):
|
| 17 |
+
def __init__(self, dim: int, eps: float):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.eps = eps
|
| 20 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 21 |
+
|
| 22 |
+
def _norm(self, x):
|
| 23 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
output = self._norm(x.float()).type_as(x)
|
| 27 |
+
return output * self.weight
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def broadcat(tensors, dim=-1):
|
| 31 |
+
num_tensors = len(tensors)
|
| 32 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 33 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
| 34 |
+
shape_len = list(shape_lens)[0]
|
| 35 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 36 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 37 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 38 |
+
assert all(
|
| 39 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
| 40 |
+
), "invalid dimensions for broadcastable concatentation"
|
| 41 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 42 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 43 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 44 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 45 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 46 |
+
return torch.cat(tensors, dim=dim)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def rotate_half(x):
|
| 50 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 51 |
+
x1, x2 = x.unbind(dim=-1)
|
| 52 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 53 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RotaryPositionalEmbedding(nn.Module):
|
| 57 |
+
|
| 58 |
+
def __init__(self,
|
| 59 |
+
head_dim,
|
| 60 |
+
cp_split_hw=None
|
| 61 |
+
):
|
| 62 |
+
"""Rotary positional embedding for 3D
|
| 63 |
+
Reference : https://blog.eleuther.ai/rotary-embeddings/
|
| 64 |
+
Paper: https://arxiv.org/pdf/2104.09864.pdf
|
| 65 |
+
Args:
|
| 66 |
+
dim: Dimension of embedding
|
| 67 |
+
base: Base value for exponential
|
| 68 |
+
"""
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.head_dim = head_dim
|
| 71 |
+
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
|
| 72 |
+
self.cp_split_hw = cp_split_hw
|
| 73 |
+
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
|
| 74 |
+
self.base = 10000
|
| 75 |
+
self.freqs_dict = {}
|
| 76 |
+
|
| 77 |
+
def register_grid_size(self, grid_size):
|
| 78 |
+
if grid_size not in self.freqs_dict:
|
| 79 |
+
self.freqs_dict.update({
|
| 80 |
+
grid_size: self.precompute_freqs_cis_3d(grid_size)
|
| 81 |
+
})
|
| 82 |
+
|
| 83 |
+
def precompute_freqs_cis_3d(self, grid_size):
|
| 84 |
+
num_frames, height, width = grid_size
|
| 85 |
+
dim_t = self.head_dim - 4 * (self.head_dim // 6)
|
| 86 |
+
dim_h = 2 * (self.head_dim // 6)
|
| 87 |
+
dim_w = 2 * (self.head_dim // 6)
|
| 88 |
+
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
| 89 |
+
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
| 90 |
+
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
| 91 |
+
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
|
| 92 |
+
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
|
| 93 |
+
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
|
| 94 |
+
grid_t = torch.from_numpy(grid_t).float()
|
| 95 |
+
grid_h = torch.from_numpy(grid_h).float()
|
| 96 |
+
grid_w = torch.from_numpy(grid_w).float()
|
| 97 |
+
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
| 98 |
+
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
| 99 |
+
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
| 100 |
+
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
| 101 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
| 102 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
| 103 |
+
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
| 104 |
+
# (T H W D)
|
| 105 |
+
freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
| 106 |
+
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
| 107 |
+
# with torch.no_grad():
|
| 108 |
+
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
|
| 109 |
+
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
|
| 110 |
+
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
| 111 |
+
|
| 112 |
+
return freqs
|
| 113 |
+
|
| 114 |
+
def forward(self, q, k, grid_size):
|
| 115 |
+
"""3D RoPE.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
query: [B, head, seq, head_dim]
|
| 119 |
+
key: [B, head, seq, head_dim]
|
| 120 |
+
Returns:
|
| 121 |
+
query and key with the same shape as input.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
if grid_size not in self.freqs_dict:
|
| 125 |
+
self.register_grid_size(grid_size)
|
| 126 |
+
|
| 127 |
+
freqs_cis = self.freqs_dict[grid_size].to(q.device)
|
| 128 |
+
q_, k_ = q.float(), k.float()
|
| 129 |
+
freqs_cis = freqs_cis.float().to(q.device)
|
| 130 |
+
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
| 131 |
+
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
| 132 |
+
q_ = (q_ * cos) + (rotate_half(q_) * sin)
|
| 133 |
+
k_ = (k_ * cos) + (rotate_half(k_) * sin)
|
| 134 |
+
|
| 135 |
+
return q_.type_as(q), k_.type_as(k)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = None,
        cp_split_hw: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)

        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw,
        )

    def _process_attn(self, q, k, v, shape):
        q = rearrange(q, "B H S D -> B S (H D)")
        k = rearrange(k, "B H S D -> B S (H D)")
        v = rearrange(v, "B H S D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)
        x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
        return x

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """Self-attention with optional conditioning-token split and pre-RoPE KV caching."""
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if return_kv:
            k_cache, v_cache = k.clone(), v.clone()

        q, k = self.rope_3d(q, k, shape)

        # cond mode
        if num_cond_latents is not None and num_cond_latents > 0:
            num_cond_latents_thw = num_cond_latents * (N // shape[0])
            # process the condition tokens
            q_cond = q[:, :, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
            # process the noise tokens
            q_noise = q[:, :, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape)
            # merge x_cond and x_noise
            x = torch.cat([x_cond, x_noise], dim=2).contiguous()
        else:
            x = self._process_attn(q, k, v, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        if return_kv:
            return x, (k_cache, v_cache)
        else:
            return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """Self-attention over new tokens, reusing cached pre-RoPE K/V for the conditioning tokens."""
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        T, H, W = shape
        k_cache, v_cache = kv_cache
        assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
        if k_cache.shape[0] == 1:
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)

        # This path requires conditioning latents; otherwise k_full/v_full below
        # would be undefined (the original guarded this with an if and no else).
        assert num_cond_latents is not None and num_cond_latents > 0, "forward_with_kv_cache requires num_cond_latents > 0"
        k_full = torch.cat([k_cache, k], dim=2).contiguous()
        v_full = torch.cat([v_cache, v], dim=2).contiguous()
        q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
        q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
        q = q_padding[:, :, -N:].contiguous()

        x = self._process_attn(q, k_full, v_full, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        return x
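

# Hedged usage sketch (illustration, not part of the original file): one plausible
# two-phase pattern for the Attention module above, with assumed shapes. The first
# call returns pre-RoPE K/V for all tokens; here we slice out the conditioning
# frame and replay it as a cache for a later call over only the new frames.
def _attention_kv_cache_demo():
    attn = Attention(dim=128, num_heads=8)
    T, H, W, C = 4, 8, 8, 128
    x = torch.randn(1, T * H * W, C)  # [B, N, C]
    out, (k_cache, v_cache) = attn(x, shape=(T, H, W), num_cond_latents=1, return_kv=True)
    n_cond = 1 * H * W  # tokens of the single conditioning frame
    cond_cache = (k_cache[:, :, :n_cond], v_cache[:, :, :n_cond])
    x_new = torch.randn(1, 3 * H * W, C)  # tokens for 3 new frames
    # RoPE inside re-derives positions on a (3 + 1, H, W) grid covering cache + new tokens
    return attn.forward_with_kv_cache(x_new, shape=(3, H, W), num_cond_latents=1, kv_cache=cond_cache)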


class MultiHeadCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)

        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        q, k = self.q_norm(q), self.k_norm(k)

        q = rearrange(q, "B S H D -> B S (H D)")
        k = rearrange(k, "B S H D -> B S (H D)")
        v = rearrange(v, "B S H D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)

        x = x.view(B, -1, C)
        x = self.proj(x)
        return x

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """
        x: [B, N, C]
        cond: [B, M, C]
        kv_seqlen: list of per-sample valid conditioning-token counts
        """
        if num_cond_latents is None or num_cond_latents == 0:
            return self._process_cross_attn(x, cond, kv_seqlen)

        B, N, C = x.shape
        assert shape is not None, "shape must be passed in when num_cond_latents > 0"
        num_cond_latents_thw = num_cond_latents * (N // shape[0])
        # Conditioning tokens skip cross attention; only the noise tokens attend to cond.
        x_noise = x[:, num_cond_latents_thw:]  # [B, N_noise, C]
        output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen)  # [B, N_noise, C]
        output = torch.cat([
            torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
            output_noise,
        ], dim=1).contiguous()

        return output
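

# Hedged sketch (illustration only): MultiHeadCrossAttention flattens the batch
# into a single packed sequence (batch dim 1) before calling flash_attention, so
# `kv_seqlen` tells the varlen kernel where each sample's text tokens end. A
# plausible way to build it from an attention mask, assuming mask shape [B, M]
# with 1s marking valid tokens:
def _build_kv_seqlen(encoder_attention_mask: torch.Tensor) -> list:
    # per-sample count of valid conditioning tokens, e.g. [7, 12] for B=2
    return encoder_attention_mask.sum(dim=1).tolist()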


class LayerNorm_FP32(nn.LayerNorm):
    def __init__(self, dim, eps, elementwise_affine):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        out = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            None if self.weight is None else self.weight.float(),
            None if self.bias is None else self.bias.float(),
            self.eps,
        ).to(origin_dtype)
        return out


def modulate_fp32(norm_func, x, shift, scale):
    # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D).
    # Ensure the modulation params are fp32.
    assert shift.dtype == torch.float32 and scale.dtype == torch.float32
    dtype = x.dtype
    x = norm_func(x.to(torch.float32))
    x = x * (scale + 1) + shift
    x = x.to(dtype)
    return x
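

# Hedged worked example (illustration only): modulate_fp32 implements the adaLN
# update norm(x) * (1 + scale) + shift in fp32. With x of shape [B, T, S, D] and
# per-frame shift/scale of shape [B, T, 1, D], broadcasting applies one affine
# modulation to all S tokens of each frame:
def _modulate_demo():
    norm = LayerNorm_FP32(64, eps=1e-6, elementwise_affine=False)
    x = torch.randn(2, 4, 16, 64)     # [B, T, S, D]
    shift = torch.zeros(2, 4, 1, 64)  # fp32, as the assert requires
    scale = torch.zeros(2, 4, 1, 64)
    # with zero shift/scale this reduces to a plain fp32 layer norm
    return modulate_fp32(norm, x, shift, scale)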


class FinalLayer_FP32(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_patch = num_patch
        self.out_channels = out_channels
        self.adaln_tembed_dim = adaln_tembed_dim

        self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))

    def forward(self, x, t, latent_shape):
        # timestep shape: [B, T, C]
        assert t.dtype == torch.float32
        B, N, C = x.shape
        T, _, _ = latent_shape

        with amp.autocast(get_device_type(), dtype=torch.float32):
            shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1)  # [B, T, 1, C]
            x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
        x = self.linear(x)
        return x


class FeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.dim = dim
        self.hidden_dim = hidden_dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, t_embed_dim, frequency_embedding_size=256):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(t_embed_dim, t_embed_dim, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class CaptionEmbedder(nn.Module):
    """
    Embeds caption features into vector representations.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.y_proj = nn.Sequential(
            nn.Linear(in_channels, hidden_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, caption):
        B, _, N, C = caption.shape  # enforce a 4-D [B, 1, N_token, C] input
        caption = self.y_proj(caption)
        return caption
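

# Hedged worked example (illustration only): timestep_embedding concatenates
# cos/sin terms over geometrically spaced frequencies. For dim=4 and the default
# max_period 10000, half=2 gives freqs [1, 10^-2], so t=5 maps to
# [cos(5), cos(0.05), sin(5), sin(0.05)].
def _timestep_embedding_demo():
    t = torch.tensor([5.0])
    emb = TimestepEmbedder.timestep_embedding(t, dim=4)  # shape [1, 4]
    return emb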


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2, 4, 4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # pad so each spatial/temporal dim is divisible by the patch size
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        B, C, T, H, W = x.shape
        x = self.proj(x)  # (B, C, T, H, W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
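

# Hedged shape walk-through (illustration only): with patch_size (1, 2, 2) and
# embed_dim 128, a latent video [B, C, T, H, W] = [1, 16, 4, 32, 32] is projected
# by the strided Conv3d to [1, 128, 4, 16, 16] and flattened to [1, 1024, 128]
# (N = 4 * 16 * 16 tokens).
def _patch_embed_demo():
    embed = PatchEmbed3D(patch_size=(1, 2, 2), in_chans=16, embed_dim=128)
    x = torch.randn(1, 16, 4, 32, 32)
    return embed(x)  # [1, 1024, 128]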


class LongCatSingleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        adaln_tembed_dim: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params=None,
        cp_split_hw=None,
    ):
        super().__init__()

        self.hidden_size = hidden_size

        # scale and gate modulation
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True),
        )

        self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)

        self.attn = Attention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
            enable_bsa=enable_bsa,
            bsa_params=bsa_params,
            cp_split_hw=cp_split_hw,
        )
        self.cross_attn = MultiHeadCrossAttention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
        )
        self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))

    def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
        """
        x: [B, N, C]
        y: [1, N_valid_tokens, C]
        t: [B, T, C_t]
        y_seqlen: list of length B
        latent_shape: latent shape of a single item
        """
        x_dtype = x.dtype

        B, N, C = x.shape
        T, _, _ = latent_shape  # N may differ from T*H*W in case of CP split on H*W.

        # compute modulation params in fp32
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            shift_msa, scale_msa, gate_msa, \
            shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1)  # [B, T, 1, C]

        # self attn with modulation
        x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)

        if kv_cache is not None:
            kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
            attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
        else:
            attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)

        if return_kv:
            x_s, kv_cache = attn_outputs
        else:
            x_s = attn_outputs

        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_msa * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)

        # cross attn
        if not skip_crs_attn:
            if kv_cache is not None:
                num_cond_latents = None
            x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)

        # ffn with modulation
        x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N // T, C), shift_mlp, scale_mlp).view(B, -1, C)
        x_s = self.ffn(x_m)
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_mlp * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)

        if return_kv:
            return x, kv_cache
        else:
            return x


class LongCatVideoTransformer3DModel(torch.nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        depth: int = 48,
        num_heads: int = 32,
        caption_channels: int = 4096,
        mlp_ratio: int = 4,
        adaln_tembed_dim: int = 512,
        frequency_embedding_size: int = 256,
        # default params
        patch_size: Tuple[int] = (1, 2, 2),
        # attention config
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = True,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
        cp_split_hw: Optional[List[int]] = [1, 1],
        text_tokens_zero_pad: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cp_split_hw = cp_split_hw

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
        )

        self.blocks = nn.ModuleList(
            [
                LongCatSingleStreamBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    adaln_tembed_dim=adaln_tembed_dim,
                    enable_flashattn3=enable_flashattn3,
                    enable_flashattn2=enable_flashattn2,
                    enable_xformers=enable_xformers,
                    enable_bsa=enable_bsa,
                    bsa_params=bsa_params,
                    cp_split_hw=cp_split_hw,
                )
                for _ in range(depth)
            ]
        )

        self.final_layer = FinalLayer_FP32(
            hidden_size,
            np.prod(self.patch_size),
            out_channels,
            adaln_tembed_dim,
        )

        self.gradient_checkpointing = False
        self.text_tokens_zero_pad = text_tokens_zero_pad

        self.lora_dict = {}
        self.active_loras = []

    def enable_loras(self, lora_key_list=[]):
        self.disable_all_loras()

        module_loras = {}  # {module_name: [lora1, lora2, ...]}
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype

        for lora_key in lora_key_list:
            if lora_key in self.lora_dict:
                for lora in self.lora_dict[lora_key].loras:
                    lora.to(model_device, dtype=model_dtype, non_blocking=True)
                    module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
                    if module_name not in module_loras:
                        module_loras[module_name] = []
                    module_loras[module_name].append(lora)
                self.active_loras.append(lora_key)

        for module_name, loras in module_loras.items():
            module = self._get_module_by_name(module_name)
            if not hasattr(module, 'org_forward'):
                module.org_forward = module.forward
            module.forward = self._create_multi_lora_forward(module, loras)

    def _create_multi_lora_forward(self, module, loras):
        def multi_lora_forward(x, *args, **kwargs):
            weight_dtype = x.dtype
            org_output = module.org_forward(x, *args, **kwargs)

            total_lora_output = 0
            for lora in loras:
                if lora.use_lora:
                    lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
                    lx = lora.lora_up(lx)
                    lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
                    total_lora_output += lora_output

            return org_output + total_lora_output

        return multi_lora_forward

    def _get_module_by_name(self, module_name):
        try:
            module = self
            for part in module_name.split('.'):
                module = getattr(module, part)
            return module
        except AttributeError as e:
            raise ValueError(f"Cannot find module: {module_name}, error: {e}")

    def disable_all_loras(self):
        for name, module in self.named_modules():
            if hasattr(module, 'org_forward'):
                module.forward = module.org_forward
                delattr(module, 'org_forward')

        for lora_key, lora_network in self.lora_dict.items():
            for lora in lora_network.loras:
                lora.to("cpu")

        self.active_loras.clear()

    def enable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = True

    def disable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = False

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask=None,
        num_cond_latents=0,
        return_kv=False,
        kv_cache_dict={},
        skip_crs_attn=False,
        offload_kv_cache=False,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        B, _, T, H, W = hidden_states.shape

        N_t = T // self.patch_size[0]
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        assert self.patch_size[0] == 1, "Currently, the 3D x_embedder must not compress the temporal dimension."

        # expand the shape of timestep from [B] to [B, T]
        if len(timestep.shape) == 1:
            timestep = timestep.unsqueeze(1).expand(-1, N_t).clone()  # [B, T]
            timestep[:, :num_cond_latents] = 0

        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype)

        hidden_states = self.x_embedder(hidden_states)  # [B, N, C]

        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]

        encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]

        if self.text_tokens_zero_pad and encoder_attention_mask is not None:
            encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
            encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)

        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
            encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1])  # [1, N_valid_tokens, C]
            y_seqlens = encoder_attention_mask.sum(dim=1).tolist()  # [B]
        else:
            y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
            encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])

        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
        #     hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
        #     hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")

        # blocks
        kv_cache_dict_ret = {}
        for i, block in enumerate(self.blocks):
            block_outputs = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=hidden_states,
                y=encoder_hidden_states,
                t=t,
                y_seqlen=y_seqlens,
                latent_shape=(N_t, N_h, N_w),
                num_cond_latents=num_cond_latents,
                return_kv=return_kv,
                kv_cache=kv_cache_dict.get(i, None),
                skip_crs_attn=skip_crs_attn,
            )

            if return_kv:
                hidden_states, kv_cache = block_outputs
                if offload_kv_cache:
                    kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
                else:
                    kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
            else:
                hidden_states = block_outputs

        hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w))  # [B, N, C=T_p*H_p*W_p*C_out]

        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)

        hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        hidden_states = hidden_states.to(torch.float32)

        if return_kv:
            return hidden_states, kv_cache_dict_ret
        else:
            return hidden_states

    def unpatchify(self, x, N_t, N_h, N_w):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    @staticmethod
    def state_dict_converter():
        return LongCatVideoTransformer3DModelDictConverter()


class LongCatVideoTransformer3DModelDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
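

# Hedged usage sketch (illustration, not part of the original file): a minimal
# end-to-end call with deliberately tiny dimensions. Real checkpoints use the
# defaults above (hidden_size 4096, depth 48); the shapes and all-ones mask here
# are assumptions for illustration.
def _longcat_forward_demo():
    model = LongCatVideoTransformer3DModel(
        in_channels=16, out_channels=16, hidden_size=64, depth=2,
        num_heads=4, caption_channels=64, adaln_tembed_dim=32,
    )
    latents = torch.randn(1, 16, 4, 16, 16)  # [B, C, T, H, W]
    timestep = torch.tensor([500.0])
    text = torch.randn(1, 1, 8, 64)          # [B, 1, N_token, caption_channels]
    mask = torch.ones(1, 8)                  # all 8 text tokens valid
    return model(latents, timestep, text, encoder_attention_mask=mask)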
diffsynth/models/ltx2_audio_vae.py
ADDED
@@ -0,0 +1,1872 @@
from typing import Set, Tuple, Optional, List
from enum import Enum
import math
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer


class AudioProcessor(nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        sample_rate: int = 16000,
        mel_bins: int = 64,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample waveform to the target sample rate if needed."""
        if source_rate == target_rate:
            return waveform
        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
        return resampled.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)

        mel = self.mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))

        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
        return mel.permute(0, 1, 3, 2).contiguous()
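

# Hedged usage sketch (illustration only): converting one second of 44.1 kHz audio
# to a log-mel spectrogram. With hop_length 160 at 16 kHz the processor yields
# roughly 100 mel frames per second; the exact count depends on centering/padding.
def _audio_processor_demo():
    processor = AudioProcessor(sample_rate=16000, mel_bins=64, mel_hop_length=160)
    waveform = torch.randn(1, 1, 44100)  # [batch, channels, samples] at 44.1 kHz
    mel = processor.waveform_to_mel(waveform, waveform_sample_rate=44100)
    return mel  # [1, 1, ~101, 64], time-major after the final permute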


class AudioPatchifier(Patchifier):
    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.

        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute
            the corresponding timing metadata when needed.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.

        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.

        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the
            timing metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def unpatchify_audio(
        self,
        audio_latents: torch.Tensor,
        channels: int,
        mel_bins: int,
    ) -> torch.Tensor:
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=channels,
            f=mel_bins,
        )
        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.

        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
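

# Hedged worked example (illustration only): the timing math above, in numbers.
# With hop_length 160, sample_rate 16000 and a downsample factor of 4, latent
# index 10 maps to mel frame 40; causal mode shifts it to frame 40 + 1 - 4 = 37,
# i.e. 37 * 160 / 16000 = 0.37 s.
def _audio_timing_demo():
    patchifier = AudioPatchifier(patch_size=1)  # defaults: 16 kHz, hop 160, factor 4, causal
    times = patchifier._get_audio_latent_time_in_sec(10, 11, torch.float32)
    return times  # tensor([0.3700])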


class AttentionType(Enum):
    """Enum for specifying the attention mechanism type."""

    VANILLA = "vanilla"
    LINEAR = "linear"
    NONE = "none"


class AttnBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b, hw, c
        k = k.reshape(b, c, h * w).contiguous()  # b, c, hw
        w_ = torch.bmm(q, k).contiguous()  # b, hw, hw; w_[b, i, j] = sum_c q[b, i, c] * k[b, c, j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b, hw, hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b, c, hw (hw of q); h_[b, c, j] = sum_i v[b, c, i] * w_[b, i, j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    match attn_type:
        case AttentionType.VANILLA:
            return AttnBlock(in_channels, norm_type=norm_type)
        case AttentionType.NONE:
            return torch.nn.Identity()
        case AttentionType.LINEAR:
            raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
        case _:
            raise ValueError(f"Unknown attention type: {attn_type}")
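

# Hedged note (illustration only): the bmm/softmax sequence in AttnBlock computes
# standard scaled dot-product attention over the h*w positions. An equivalent
# formulation using PyTorch's fused kernel, assuming the same 1x1-conv projections:
def _attn_equivalent(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    b, c, h, w = q.shape
    q_ = q.reshape(b, 1, c, h * w).transpose(-1, -2)  # [b, 1, hw, c]
    k_ = k.reshape(b, 1, c, h * w).transpose(-1, -2)
    v_ = v.reshape(b, 1, c, h * w).transpose(-1, -2)
    out = F.scaled_dot_product_attention(q_, k_, v_)  # softmax(q k^T / sqrt(c)) v
    return out.transpose(-1, -2).reshape(b, c, h, w)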
class CausalityAxis(Enum):
|
| 330 |
+
"""Enum for specifying the causality axis in causal convolutions."""
|
| 331 |
+
|
| 332 |
+
NONE = None
|
| 333 |
+
WIDTH = "width"
|
| 334 |
+
HEIGHT = "height"
|
| 335 |
+
WIDTH_COMPATIBILITY = "width-compatibility"
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class CausalConv2d(torch.nn.Module):
|
| 339 |
+
"""
|
| 340 |
+
A causal 2D convolution.
|
| 341 |
+
This layer ensures that the output at time `t` only depends on inputs
|
| 342 |
+
at time `t` and earlier. It achieves this by applying asymmetric padding
|
| 343 |
+
to the time dimension (width) before the convolution.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(
|
| 347 |
+
self,
|
| 348 |
+
in_channels: int,
|
| 349 |
+
out_channels: int,
|
| 350 |
+
kernel_size: int | tuple[int, int],
|
| 351 |
+
stride: int = 1,
|
| 352 |
+
dilation: int | tuple[int, int] = 1,
|
| 353 |
+
groups: int = 1,
|
| 354 |
+
bias: bool = True,
|
| 355 |
+
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
| 356 |
+
) -> None:
|
| 357 |
+
super().__init__()
|
| 358 |
+
|
| 359 |
+
self.causality_axis = causality_axis
|
| 360 |
+
|
| 361 |
+
# Ensure kernel_size and dilation are tuples
|
| 362 |
+
kernel_size = torch.nn.modules.utils._pair(kernel_size)
|
| 363 |
+
dilation = torch.nn.modules.utils._pair(dilation)
|
| 364 |
+
|
| 365 |
+
# Calculate padding dimensions
|
| 366 |
+
pad_h = (kernel_size[0] - 1) * dilation[0]
|
| 367 |
+
pad_w = (kernel_size[1] - 1) * dilation[1]
|
| 368 |
+
|
| 369 |
+
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
| 370 |
+
match self.causality_axis:
|
| 371 |
+
case CausalityAxis.NONE:
|
| 372 |
+
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
| 373 |
+
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
| 374 |
+
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
| 375 |
+
case CausalityAxis.HEIGHT:
|
| 376 |
+
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
| 377 |
+
case _:
|
| 378 |
+
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
| 379 |
+
|
| 380 |
+
# The internal convolution layer uses no padding, as we handle it manually
|
| 381 |
+
self.conv = torch.nn.Conv2d(
|
| 382 |
+
in_channels,
|
| 383 |
+
out_channels,
|
| 384 |
+
kernel_size,
|
| 385 |
+
stride=stride,
|
| 386 |
+
padding=0,
|
| 387 |
+
dilation=dilation,
|
| 388 |
+
groups=groups,
|
| 389 |
+
bias=bias,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 393 |
+
# Apply causal padding before convolution
|
| 394 |
+
x = F.pad(x, self.padding)
|
| 395 |
+
return self.conv(x)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: int | tuple[int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride
        padding: Padding (if None, symmetric padding is derived from the kernel size;
            ignored in the causal branch, where CausalConv2d pads internally)
        dilation: Dilation rate
        groups: Number of groups for grouped convolution
        bias: Whether to use bias
        causality_axis: Dimension along which to apply causality.
    Returns:
        Either a regular Conv2d or CausalConv2d layer
    """
    if causality_axis is not None:
        # For causal convolution, padding is handled internally by CausalConv2d
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
    else:
        # For non-causal convolution, use symmetric padding if not specified
        if padding is None:
            padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)

        return torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )


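For reference, a brief usage sketch of the dispatch above (values are illustrative; at stride 1 both branches preserve the spatial size):

import torch

causal = make_conv2d(16, 32, kernel_size=3, causality_axis=CausalityAxis.HEIGHT)  # -> CausalConv2d
plain = make_conv2d(16, 32, kernel_size=3, causality_axis=None)                   # -> torch.nn.Conv2d, padding=1
x = torch.randn(2, 16, 32, 32)
assert causal(x).shape == plain(x).shape == (2, 32, 32, 32)
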
LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
        super().__init__()
        # One dilated conv per dilation value, each paired with a dilation-1 conv below.
        self.convs1 = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=d, padding="same")
                for d in dilation
            ]
        )

        self.convs2 = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding="same")
                for _ in dilation
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = conv1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = conv2(xt)
            x = xt + x
        return x


class ResBlock2(torch.nn.Module):
    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=d, padding="same")
                for d in dilation
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = conv(xt)
            x = xt + x
        return x


class ResnetBlock(torch.nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h


class Downsample(torch.nn.Module):
    """
    A downsampling layer that can use either a strided convolution
    or average pooling. Supports standard and causal padding for the
    convolutional mode.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
            raise ValueError("Causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Downsample with a strided convolution. Torch convolutions have no
            # asymmetric padding, so the padding is applied manually in forward().
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.with_conv:
            # Padding tuple is in the order: (left, right, top, bottom).
            match self.causality_axis:
                case CausalityAxis.NONE:
                    pad = (0, 1, 0, 1)
                case CausalityAxis.WIDTH:
                    pad = (2, 0, 0, 1)
                case CausalityAxis.HEIGHT:
                    pad = (0, 1, 2, 0)
                case CausalityAxis.WIDTH_COMPATIBILITY:
                    pad = (1, 0, 0, 1)
                case _:
                    raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        return x


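Worked shape example for the causal branch above (a sketch with illustrative numbers): with kernel 3, stride 2, and WIDTH causality the pad tuple (2, 0, 0, 1) halves even spatial sizes exactly:

import torch

ds = Downsample(in_channels=4, with_conv=True, causality_axis=CausalityAxis.WIDTH)
x = torch.randn(1, 4, 8, 8)
# pad (left=2, right=0, top=0, bottom=1) -> (1, 4, 9, 10); k=3, s=2 conv -> (1, 4, 4, 4)
assert ds(x).shape == (1, 4, 4, 4)
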
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the downsampling path with residual blocks, attention, and downsampling layers."""
    down_modules = torch.nn.ModuleList()
    curr_res = resolution
    in_ch_mult = (1, *tuple(ch_mult))
    block_in = ch

    for i_level in range(num_resolutions):
        block = torch.nn.ModuleList()
        attn = torch.nn.ModuleList()
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]

        for _ in range(num_res_blocks):
            block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        down = torch.nn.Module()
        down.block = block
        down.attn = attn
        if i_level != num_resolutions - 1:
            down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res = curr_res // 2
        down_modules.append(down)

    return down_modules, block_in


class Upsample(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        # Drop the FIRST element along the causal axis to undo the encoder's padding,
        # keeping the length at 1 + 2 * n.
        # For example, if the input is [0, 1, 2], the output after interpolation is
        # [0, 0, 1, 1, 2, 2]. The causal convolution pads it as [-, -, 0, 0, 1, 1, 2, 2],
        # so the output elements rely on the following windows:
        # 0: [-,-,0]
        # 1: [-,0,0]
        # 2: [0,0,1]
        # 3: [0,1,1]
        # 4: [1,1,2]
        # 5: [1,2,2]
        # Note that the first and second output elements rely only on the first input
        # element, while all other elements rely on two input elements.
        # So we drop the first element (rather than the last) to undo the padding.
        # This is a no-op for non-causal convolutions.
        match self.causality_axis:
            case CausalityAxis.NONE:
                pass  # x remains unchanged
            case CausalityAxis.HEIGHT:
                x = x[:, :, 1:, :]
            case CausalityAxis.WIDTH:
                x = x[:, :, :, 1:]
            case CausalityAxis.WIDTH_COMPATIBILITY:
                pass  # x remains unchanged
            case _:
                raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x


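The comment above can be checked numerically (a sketch using the module as defined): a causal-height Upsample maps T rows to 2T - 1 rows because the first interpolated row is dropped, while the non-causal width doubles exactly:

import torch

up = Upsample(in_channels=4, with_conv=True, causality_axis=CausalityAxis.HEIGHT)
x = torch.randn(1, 4, 3, 8)
y = up(x)
# nearest interpolation gives (6, 16); dropping the first row along height leaves 5 = 2*3 - 1
assert y.shape == (1, 4, 5, 16)
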
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the upsampling path with residual blocks, attention, and upsampling layers."""
    up_modules = torch.nn.ModuleList()
    block_in = initial_block_channels
    curr_res = resolution // (2 ** (num_resolutions - 1))

    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        block_out = ch * ch_mult[level]

        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        if level != 0:
            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res *= 2

        up_modules.insert(0, stage)

    return up_modules, block_in


class PerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation.
    These statistics are computed over the entire dataset and stored in the model's
    checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)


LATENT_DOWNSAMPLE_FACTOR = 4


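A round-trip sketch for PerChannelStatistics above (the buffer values here are hypothetical; in practice they are loaded from the checkpoint):

import torch

stats = PerChannelStatistics(latent_channels=4)
stats.get_buffer("mean-of-means").fill_(0.5)  # hypothetical statistics
stats.get_buffer("std-of-means").fill_(2.0)
z = torch.randn(2, 4)
assert torch.allclose(stats.un_normalize(stats.normalize(z)), z, atol=1e-6)
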
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Build the middle block with two ResNet blocks and optional attention."""
    mid = torch.nn.Module()
    mid.block_1 = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
    mid.block_2 = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    return mid


def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Run features through the middle block."""
    features = mid.block_1(features, temb=None)
    features = mid.attn_1(features)
    return mid.block_2(features, temb=None)


class LTX2AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.
    The encoder uses a series of downsampling blocks with residual connections,
    attention mechanisms, and configurable causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int = 128,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Set[int] = set(),
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int = 2,
        resolution: int = 256,
        z_channels: int = 8,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = False,
        norm_type: NormType = NormType.PIXEL,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels used in the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
            resolution: Input spatial resolution of the spectrogram.
            z_channels: Number of channels in the latent representation.
            norm_type: Normalization layer type to use within the network (e.g., group, batch).
            causality_axis: Axis along which convolutions should be causal (e.g., time axis).
            sample_rate: Audio sample rate in Hz for the input signals.
            mel_hop_length: Hop length used when computing the mel spectrogram.
            n_fft: FFT size used to compute the spectrogram.
            mel_bins: Number of mel-frequency bins in the input spectrogram.
            in_channels: Number of channels in the input spectrogram tensor.
            double_z: If True, predict both mean and log-variance (doubling latent channels).
            is_causal: If True, use causal convolutions suitable for streaming setups.
            dropout: Dropout probability used in residual and mid blocks.
            attn_type: Type of attention mechanism to use in attention blocks.
            resamp_with_conv: If True, perform resolution changes using strided convolutions.
            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
        """
        super().__init__()

        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.non_linearity = torch.nn.SiLU()

        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode an audio spectrogram into latent representations.
        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
        Returns:
            Encoded latent representation of shape (batch, channels, frames, mel_bins)
        """
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)

        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.
        When the encoder is configured with ``double_z=True``, the final
        convolution produces twice the number of latent channels, typically
        interpreted as two concatenated tensors along the channel dimension
        (e.g., mean and variance or other auxiliary parameters).
        This method intentionally uses only the first half of the channels
        (the "mean" component) as input to the patchifier and normalization
        logic. The remaining channels are left unchanged by this method and
        are expected to be consumed elsewhere in the VAE pipeline.
        If ``double_z=False``, the encoder output already contains only the
        mean latents and the chunking operation simply returns that tensor.
        """
        means = torch.chunk(latent_output, 2, dim=1)[0]
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)


class LTX2AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.
    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int = 128,
        out_ch: int = 2,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Set[int] = set(),
        resolution: int = 256,
        z_channels: int = 8,
        norm_type: NormType = NormType.PIXEL,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
    ) -> None:
        """
        Initialize the Decoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust the output shape to match target dimensions for variable-length audio.
        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.
        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h


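The crop-then-pad contract of `_adjust_output_shape` can be illustrated with plain tensor ops (a standalone sketch, not a call into the class; shapes are illustrative):

import torch
import torch.nn.functional as F

decoded = torch.randn(1, 2, 100, 64)  # (batch, channels, time, freq)
target_time, target_freq = 97, 64
out = decoded[:, :, : min(100, target_time), : min(64, target_freq)]  # crop first
pad_t, pad_f = target_time - out.shape[2], target_freq - out.shape[3]
out = F.pad(out, (0, max(pad_f, 0), 0, max(pad_t, 0)))  # then pad if short
assert out.shape == (1, 2, 97, 64)
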
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    return int((kernel_size * dilation - dilation) / 2)


# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adapted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------


def _sinc(x: torch.Tensor) -> torch.Tensor:
    return torch.where(
        x == 0,
        torch.tensor(1.0, device=x.device, dtype=x.dtype),
        torch.sin(math.pi * x) / math.pi / x,
    )


def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amplitude > 50.0:
        beta = 0.1102 * (amplitude - 8.7)
    elif amplitude >= 21.0:
        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
    time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
        filter_ /= filter_.sum()
    return filter_.view(1, 1, kernel_size)


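Because the taps are normalized by their sum, the filter has unit DC gain; a quick check of the function as defined above (parameter values are illustrative):

import torch

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert filt.shape == (1, 1, 12)
assert torch.allclose(filt.sum(), torch.tensor(1.0), atol=1e-6)  # unit DC gain
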
class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Cutoff must be non-negative.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, n_channels, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)


class UpSample1d(nn.Module):
    def __init__(
        self,
        ratio: int = 2,
        kernel_size: int | None = None,
        persistent: bool = True,
        window_type: str = "kaiser",
    ) -> None:
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio
            time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
            sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser-windowed sinc filter (BigVGAN default).
            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            sinc_filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            )

        self.register_buffer("filter", sinc_filter, persistent=persistent)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, n_channels, _ = x.shape
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
        x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
        return x[..., self.pad_left : -self.pad_right]


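Usage sketch for the default Kaiser branch above: after the transposed convolution and the left/right crop, the output is exactly `ratio` times longer than the input:

import torch

up = UpSample1d(ratio=2)
x = torch.randn(1, 3, 100)
assert up(x).shape == (1, 3, 200)  # length scales exactly by the ratio
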
class DownSample1d(nn.Module):
    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lowpass(x)


class Activation1d(nn.Module):
    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.act(x)
        return self.downsample(x)


class Snake(nn.Module):
    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)


class SnakeBeta(nn.Module):
    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)


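The snake nonlinearity is x + (1/alpha) * sin^2(alpha * x); a sketch checking that identity against the module above (with the default log-scale parametrization, alpha = exp(0) = 1):

import torch

snake = Snake(in_features=3)
x = torch.randn(2, 3, 10)
expected = x + torch.sin(x).pow(2) / (1.0 + snake.eps)
assert torch.allclose(snake(x), expected, atol=1e-6)
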
class AMPBlock1(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=d, padding=get_padding(kernel_size, d))
                for d in dilation
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                for _ in dilation
            ]
        )

        self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = x + xt
        return x


class LTX2Vocoder(torch.nn.Module):
    """
    LTX2Vocoder model for synthesizing audio from Mel spectrograms.
    Args:
        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
        upsample_rates: List of upsampling rates.
            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
        upsample_initial_channel: Initial number of channels for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
        resblock: Type of residual block to use ("1", "2", or "AMP1").
            This value is read from the checkpoint at `config.vocoder.resblock`.
        output_sampling_rate: Waveform sample rate.
            This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
        activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
        use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
        apply_final_activation: Whether to apply the final tanh/clamp activation.
        use_bias_at_final: Whether to use bias in the final conv layer.
    """

    def __init__(  # noqa: PLR0913
        self,
        resblock_kernel_sizes: List[int] | None = None,
        upsample_rates: List[int] | None = None,
        upsample_kernel_sizes: List[int] | None = None,
        resblock_dilation_sizes: List[List[int]] | None = None,
        upsample_initial_channel: int = 1024,
        resblock: str = "1",
        output_sampling_rate: int = 24000,
        activation: str = "snake",
        use_tanh_at_final: bool = True,
        apply_final_activation: bool = True,
        use_bias_at_final: bool = True,
    ) -> None:
        super().__init__()

        # Mutable default values are not supported as default arguments,
        # so the list-valued defaults are materialized here.
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if upsample_rates is None:
            upsample_rates = [6, 5, 2, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 15, 8, 4, 4]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.output_sampling_rate = output_sampling_rate
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.use_tanh_at_final = use_tanh_at_final
        self.apply_final_activation = apply_final_activation
        self.is_amp = resblock == "AMP1"

        # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
        # bins each), 2 output channels.
        self.conv_pre = nn.Conv1d(
            in_channels=128,
            out_channels=upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )
        resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1

        self.ups = nn.ModuleList(
            nn.ConvTranspose1d(
                upsample_initial_channel // (2**i),
                upsample_initial_channel // (2 ** (i + 1)),
                kernel_size,
                stride,
                padding=(kernel_size - stride) // 2,
            )
            for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
        )

        final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
        self.resblocks = nn.ModuleList()

        for i in range(len(upsample_rates)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
                if self.is_amp:
                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
                else:
                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations))

        if self.is_amp:
            self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
        else:
            self.act_post = nn.LeakyReLU()

        # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
        self.conv_post = nn.Conv1d(
            in_channels=final_channels,
            out_channels=2,
            kernel_size=7,
            stride=1,
            padding=3,
            bias=use_bias_at_final,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vocoder.
        Args:
            x: Input Mel spectrogram tensor. Can be either:
                - 3D: (batch_size, time, mel_bins) for mono
                - 4D: (batch_size, 2, time, mel_bins) for stereo
        Returns:
            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """
        # Swap the last two axes so time is the final dimension; negative indices
        # keep this valid for both 3D (mono) and 4D (stereo) inputs.
        x = x.transpose(-2, -1)

        if x.dim() == 4:  # stereo
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            x = einops.rearrange(x, "b s c t -> b (s c) t")

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            if not self.is_amp:
                x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels

            # Evaluate all resblocks with the same input tensor so they can run
            # independently (and thus in parallel on accelerator hardware) before
            # aggregating their outputs via mean.
            block_outputs = torch.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                dim=0,
            )
            x = block_outputs.mean(dim=0)

        x = self.act_post(x)
        x = self.conv_post(x)

        if self.apply_final_activation:
            x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)

        return x

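Shape-level usage sketch for the vocoder above with its default configuration (weights are random here, so only shapes are meaningful); the hop factor is the product of the upsample rates, 6*5*2*2*2 = 240:

import torch

voc = LTX2Vocoder()
mel = torch.randn(1, 2, 50, 64)  # (batch, stereo, time, mel_bins)
wav = voc(mel)
assert wav.shape == (1, 2, 50 * 240)  # each mel frame becomes 240 waveform samples
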
class _STFTFn(nn.Module):
    """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
    bfloat16 bases from training ensures the mel values fed to the BWE generator are
    bit-identical to what it was trained on.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrograms from a batch of waveforms.
        Applies causal (left-only) padding of win_length - hop_length samples so that
        each output frame depends only on past and present input (no lookahead).
        Args:
            y: Waveform tensor of shape (B, T).
        Returns:
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
        """
        if y.dim() == 2:
            y = y.unsqueeze(1)  # (B, 1, T)
        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
        y = F.pad(y, (left_pad, 0))
        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
        n_freqs = spec.shape[1] // 2
        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
        magnitude = torch.sqrt(real**2 + imag**2)
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase


class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
    waveform and projecting the linear magnitude spectrum onto the mel filterbank.
    The module's state dict layout matches the 'mel_stft.*' keys stored in the
    checkpoint (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
    ) -> None:
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        # Initialized to zeros; load_state_dict overwrites with the checkpoint's
        # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
        n_freqs = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))

    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy


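# Editor's sketch (hypothetical shape check, not part of the module): with the
# configuration used by LTX2VocoderWithBWE below (filter_length=512, hop_length=80,
# win_length=512, 64 mel channels), a 16 kHz waveform yields 16000 / 80 = 200 frames
# per second, and the causal left pad is win_length - hop_length = 432 samples. Note
# mel_basis is zeros until load_state_dict runs, so only the shapes are meaningful.
def _sketch_mel_stft_shapes() -> None:
    mel_stft = MelSTFT(filter_length=512, hop_length=80, win_length=512, n_mel_channels=64)
    wav = torch.randn(2, 16000)  # (B, T): two one-second clips at 16 kHz
    log_mel, magnitude, phase, energy = mel_stft.mel_spectrogram(wav)
    assert log_mel.shape == (2, 64, 200)     # 200 frames = 16000 samples / 80 hop
    assert magnitude.shape == (2, 257, 200)  # 257 = 512 // 2 + 1 frequency bins
    assert energy.shape == (2, 200)

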
class LTX2VocoderWithBWE(nn.Module):
    """LTX2Vocoder with bandwidth-extension (BWE) upsampling.

    Chains a mel-to-wav vocoder with a BWE module that upsamples the output
    to a higher sample rate. The BWE computes a mel spectrogram from the
    vocoder output, runs it through a second generator to predict a residual,
    and adds the residual to a sinc-resampled skip connection.
    """

    def __init__(
        self,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        super().__init__()
        self.vocoder = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[5, 2, 2, 2, 2, 2],
            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=1536,
            resblock="AMP1",
            activation="snakebeta",
            use_tanh_at_final=False,
            apply_final_activation=True,
            use_bias_at_final=False,
            output_sampling_rate=input_sampling_rate,
        )
        self.bwe_generator = LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            upsample_rates=[6, 5, 2, 2, 2],
            upsample_kernel_sizes=[12, 11, 4, 4, 4],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            upsample_initial_channel=512,
            resblock="AMP1",
            activation="snakebeta",
            use_tanh_at_final=False,
            apply_final_activation=False,
            use_bias_at_final=False,
            output_sampling_rate=output_sampling_rate,
        )

        self.mel_stft = MelSTFT(
            filter_length=512,
            hop_length=hop_length,
            win_length=512,
            n_mel_channels=64,
        )
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Build the resampler on CPU so the sinc filter is materialized even when
        # the model is constructed on the meta device (SingleGPUModelBuilder pattern).
        # The filter is not stored in the checkpoint (persistent=False).
        with torch.device("cpu"):
            self.resampler = UpSample1d(
                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
            )

    @property
    def conv_pre(self) -> nn.Conv1d:
        return self.vocoder.conv_pre

    @property
    def conv_post(self) -> nn.Conv1d:
        return self.vocoder.conv_post

    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute a log-mel spectrogram from a waveform using the causal STFT bases.

        Args:
            audio: Waveform tensor of shape (B, C, T).

        Returns:
            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Run the full vocoder + BWE forward pass.

        Args:
            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
                or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.

        Returns:
            Waveform tensor of shape (B, out_channels, T_out) clamped to [-1, 1].
        """
        x = self.vocoder(mel_spec)
        _, _, length_low_rate = x.shape
        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate

        # Pad to a multiple of hop_length for an exact mel frame count.
        remainder = length_low_rate % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))

        # Compute the mel spectrogram from the vocoder output: (B, C, n_mels, T_frames).
        mel = self._compute_mel(x)

        # LTX2Vocoder.forward expects (B, C, T, mel_bins), so transpose before the BWE generator.
        mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
        residual = self.bwe_generator(mel_for_bwe)
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        return torch.clamp(residual + skip, -1, 1)[..., :output_length]
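

# Editor's note (shape bookkeeping, illustrative only): after padding, the low-rate
# waveform has length L = k * hop_length, so its mel has exactly k frames. The BWE
# generator upsamples by prod([6, 5, 2, 2, 2]) = 240 samples per frame, giving
# k * 240 = L * 3 output samples at 48 kHz, which matches the sinc-resampled skip
# path (ratio = 48000 // 16000 = 3) and is why the shape assert in forward() holds.
def _sketch_bwe_length_math(length_low_rate: int, hop_length: int = 80) -> int:
    padded = length_low_rate + (-length_low_rate) % hop_length  # pad up to a hop multiple
    frames = padded // hop_length      # mel frames fed to the BWE generator
    residual_len = frames * 240        # 240 = 6 * 5 * 2 * 2 * 2 total upsampling
    assert residual_len == padded * 3  # equals the skip connection's length
    return residual_len
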
diffsynth/models/ltx2_common.py
ADDED
@@ -0,0 +1,388 @@
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Protocol, Tuple

import torch
from torch import nn


class VideoPixelShape(NamedTuple):
    """
    Shape of the tensor representing the video pixel array. Assumes BGR channel format.
    """

    batch: int
    frames: int
    height: int
    width: int
    fps: float


class SpatioTemporalScaleFactors(NamedTuple):
    """
    Describes the spatiotemporal downscaling between decoded video space and
    the corresponding VAE latent grid.
    """

    time: int
    width: int
    height: int

    @classmethod
    def default(cls) -> "SpatioTemporalScaleFactors":
        return cls(time=8, width=32, height=32)


VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()


class VideoLatentShape(NamedTuple):
    """
    Shape of the tensor representing video in VAE latent space.

    The latent representation is a 5D tensor with dimensions ordered as
    (batch, channels, frames, height, width). Spatial and temporal dimensions
    are downscaled relative to pixel space according to the VAE's scale factors.
    """

    batch: int
    channels: int
    frames: int
    height: int
    width: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
        return VideoLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            height=shape[3],
            width=shape[4],
        )

    def mask_shape(self) -> "VideoLatentShape":
        return self._replace(channels=1)

    @staticmethod
    def from_pixel_shape(
        shape: VideoPixelShape,
        latent_channels: int = 128,
        scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
    ) -> "VideoLatentShape":
        # Use named fields rather than positional indexing: the tuple's field order is
        # (time, width, height), so indexing with [1]/[2] would swap the spatial factors.
        frames = (shape.frames - 1) // scale_factors.time + 1
        height = shape.height // scale_factors.height
        width = shape.width // scale_factors.width

        return VideoLatentShape(
            batch=shape.batch,
            channels=latent_channels,
            frames=frames,
            height=height,
            width=width,
        )

    def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
        return self._replace(
            channels=3,
            frames=(self.frames - 1) * scale_factors.time + 1,
            height=self.height * scale_factors.height,
            width=self.width * scale_factors.width,
        )


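# Editor's sketch (hypothetical values, illustrating the shape arithmetic above): a
# 121-frame 704x1216 clip maps to (121 - 1) // 8 + 1 = 16 latent frames (the causal
# VAE keeps the first frame at unit temporal stride) and 704 // 32 = 22 by
# 1216 // 32 = 38 latent pixels; upscale() inverts the frame formula exactly.
def _sketch_video_latent_shapes() -> None:
    pixel = VideoPixelShape(batch=1, frames=121, height=704, width=1216, fps=24.0)
    latent = VideoLatentShape.from_pixel_shape(pixel)
    assert latent == VideoLatentShape(batch=1, channels=128, frames=16, height=22, width=38)
    restored = latent.upscale()
    assert (restored.frames, restored.height, restored.width) == (121, 704, 1216)

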
class AudioLatentShape(NamedTuple):
    """
    Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).

    mel_bins is the number of frequency bins from the mel-spectrogram encoding.
    """

    batch: int
    channels: int
    frames: int
    mel_bins: int

    def to_torch_shape(self) -> torch.Size:
        return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])

    def mask_shape(self) -> "AudioLatentShape":
        return self._replace(channels=1, mel_bins=1)

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
        return AudioLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            mel_bins=shape[3],
        )

    @staticmethod
    def from_duration(
        batch: int,
        duration: float,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)

        return AudioLatentShape(
            batch=batch,
            channels=channels,
            frames=round(duration * latents_per_second),
            mel_bins=mel_bins,
        )

    @staticmethod
    def from_video_pixel_shape(
        shape: VideoPixelShape,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        return AudioLatentShape.from_duration(
            batch=shape.batch,
            duration=float(shape.frames) / float(shape.fps),
            channels=channels,
            mel_bins=mel_bins,
            sample_rate=sample_rate,
            hop_length=hop_length,
            audio_latent_downsample_factor=audio_latent_downsample_factor,
        )


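# Editor's sketch (hypothetical numbers, illustrating from_duration above): with the
# defaults, the audio latent rate is 16000 / 160 / 4 = 25 latent frames per second,
# so a 121-frame video at 24 fps (about 5.04 s) gets round(5.0417 * 25) = 126 audio
# latent frames.
def _sketch_audio_latent_rate() -> None:
    video = VideoPixelShape(batch=1, frames=121, height=704, width=1216, fps=24.0)
    audio = AudioLatentShape.from_video_pixel_shape(video)
    assert audio == AudioLatentShape(batch=1, channels=8, frames=126, mel_bins=16)

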
@dataclass(frozen=True)
class LatentState:
    """
    State of the latents during the diffusion denoising process.

    Attributes:
        latent: The current noisy latent tensor being denoised.
        denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
        positions: Positional indices for each latent element, used for positional embeddings.
        clean_latent: Initial state of the latent before denoising; may include conditioning latents.
    """

    latent: torch.Tensor
    denoise_mask: torch.Tensor
    positions: torch.Tensor
    clean_latent: torch.Tensor

    def clone(self) -> "LatentState":
        return LatentState(
            latent=self.latent.clone(),
            denoise_mask=self.denoise_mask.clone(),
            positions=self.positions.clone(),
            clean_latent=self.clean_latent.clone(),
        )


class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    GROUP = "group"
    PIXEL = "pixel"


class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.

    At every location, this layer normalizes the tensor by the root-mean-square of
    its values along the chosen dimension:
        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which to compute the RMS (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization along the configured dimension.
        """
        # Compute the mean of squared values along `dim`, keeping dimensions for broadcasting.
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        # Normalize by the root-mean-square (RMS).
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms


def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create a normalization layer based on the normalization type.

    Args:
        in_channels: Number of input channels.
        num_groups: Number of groups for group normalization.
        normtype: Type of normalization: NormType.GROUP or NormType.PIXEL.

    Returns:
        A normalization layer.
    """
    if normtype == NormType.GROUP:
        return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    if normtype == NormType.PIXEL:
        return PixelNorm(dim=1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")


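# Editor's sketch (illustrative usage): both norm types are drop-in for a (B, C, H, W)
# feature map; PixelNorm has no learnable parameters, while GroupNorm carries affine
# weight and bias.
def _sketch_norm_layers() -> None:
    x = torch.randn(2, 64, 8, 8)
    group = build_normalization_layer(64, num_groups=32, normtype=NormType.GROUP)
    pixel = build_normalization_layer(64, normtype=NormType.PIXEL)
    assert group(x).shape == pixel(x).shape == x.shape
    # PixelNorm makes every spatial location unit-RMS across channels.
    rms = (pixel(x) ** 2).mean(dim=1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)

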
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
    """Root-mean-square (RMS) normalize `x` over its last dimension.

    Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
    shape and forwards `weight` and `eps`.
    """
    return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)


@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.

    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.

    Attributes:
        latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
            the batch size, *T* is the total number of tokens (noisy +
            conditioning), and *D* is the input dimension.
        sigma: Current sigma value, shape ``(B,)``, used for cross-attention
            timestep calculation.
        timesteps: Per-token timestep embeddings, shape ``(B, T)``.
        positions: Positional coordinates, shape ``(B, 3, T)`` for video
            (time, height, width) or ``(B, 1, T)`` for audio.
        context: Text conditioning embeddings from the prompt encoder.
        enabled: Whether this modality is active in the current forward pass.
        context_mask: Optional mask for the text context tokens.
        attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
            Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
            attention. ``None`` means unrestricted (full) attention between
            all tokens. Built incrementally by conditioning items; see
            :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
    """

    latent: torch.Tensor  # Shape: (B, T, D)
    sigma: torch.Tensor  # Shape: (B,)
    timesteps: torch.Tensor  # Shape: (B, T)
    positions: torch.Tensor  # Shape: (B, 3, T) for video, (B, 1, T) for audio
    context: torch.Tensor
    enabled: bool = True
    context_mask: torch.Tensor | None = None
    attention_mask: torch.Tensor | None = None


def to_denoised(
    sample: torch.Tensor,
    velocity: torch.Tensor,
    sigma: float | torch.Tensor,
    calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Convert a noisy sample and its predicted denoising velocity into the denoised sample.

    Returns:
        The denoised sample.
    """
    if isinstance(sigma, torch.Tensor):
        sigma = sigma.to(calc_dtype)
    return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)


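# Editor's note (worked example of the update above, hypothetical numbers): under the
# usual flow-matching convention x_t = (1 - sigma) * x_0 + sigma * noise with velocity
# v = noise - x_0, the clean sample is recovered as x_0 = x_t - sigma * v, which is
# exactly what to_denoised computes.
def _sketch_to_denoised_roundtrip() -> None:
    x0 = torch.randn(2, 16)
    noise = torch.randn(2, 16)
    sigma = 0.7
    x_t = (1 - sigma) * x0 + sigma * noise  # noisy sample on the linear path
    v = noise - x0                          # ground-truth velocity
    assert torch.allclose(to_denoised(x_t, v, sigma), x0, atol=1e-5)

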
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Convert flattened patch tokens back into a dense spatio-temporal latent tensor.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor; either AudioLatentShape or VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Return the patch size as a tuple of (temporal, height, width) dimensions.
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...


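# Editor's sketch (a minimal video-only Patchifier satisfying the protocol above,
# assuming a trivial 1x1x1 patch; the production patchifiers may lay out tokens
# differently). patchify flattens a (B, C, F, H, W) latent grid into (B, T, C) tokens,
# unpatchify inverts it, and get_patch_grid_bounds emits unit [start, end) bounds in
# the (batch, 3, num_patches, 2) layout that get_pixel_coords below expects.
class _SketchIdentityPatchifier:
    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return (1, 1, 1)

    def patchify(self, latents: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latents.shape
        return latents.reshape(b, c, f * h * w).transpose(1, 2)  # (B, T, C)

    def unpatchify(
        self, latents: torch.Tensor, output_shape: AudioLatentShape | VideoLatentShape
    ) -> torch.Tensor:
        return latents.transpose(1, 2).reshape(output_shape.to_torch_shape())

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        b, _, f, h, w = output_shape.to_torch_shape()
        starts = torch.stack(
            torch.meshgrid(
                torch.arange(f, device=device),
                torch.arange(h, device=device),
                torch.arange(w, device=device),
                indexing="ij",
            ),
            dim=0,
        ).reshape(3, -1)  # (3, F*H*W) per-axis start coordinates
        bounds = torch.stack([starts, starts + 1], dim=-1)  # unit patches: [start, start + 1)
        return bounds.unsqueeze(0).expand(b, -1, -1, -1)

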
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.

    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`,
            with the axis dimension ordered as (frame/time, height, width).
        scale_factors: SpatioTemporalScaleFactors with integer per-axis scale factors.
            Note the tuple's field order is (time, width, height), so the factors are
            re-ordered below to match the axis layout of `latent_coords`.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
    broadcast_shape = [1] * latent_coords.ndim
    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    scale_tensor = torch.tensor(
        [scale_factors.time, scale_factors.height, scale_factors.width],
        device=latent_coords.device,
    ).view(*broadcast_shape)

    # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * scale_tensor

    if causal_fix:
        # The VAE's temporal stride for the very first frame is 1 instead of `scale_factors.time`.
        # Shift and clamp to keep the first-frame timestamps causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors.time).clamp(min=0)

    return pixel_coords
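

# Editor's note (worked example, hypothetical numbers): with the default factors
# (time=8, height=32, width=32), a latent bound starting at t=1 maps to pixel frame
# 1 * 8 = 8; with causal_fix=True it becomes 8 + 1 - 8 = 1, reflecting that the causal
# VAE packs the first output frame at unit temporal stride.
def _sketch_causal_fix() -> None:
    bounds = torch.tensor([0, 1, 2], dtype=torch.float32).reshape(1, 1, 3, 1)
    bounds = bounds.expand(1, 3, 3, 1)  # same start per axis: (batch, 3, patches, bound)
    plain = get_pixel_coords(bounds, VIDEO_SCALE_FACTORS)
    fixed = get_pixel_coords(bounds, VIDEO_SCALE_FACTORS, causal_fix=True)
    assert plain[0, 0, :, 0].tolist() == [0.0, 8.0, 16.0]
    assert fixed[0, 0, :, 0].tolist() == [0.0, 1.0, 9.0]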