ViTeX-Bench committed
Commit bc8c4af · verified · 1 Parent(s): 38bf857

Bundle diffsynth library (no external repo dependency)

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. diffsynth/__init__.py +1 -0
  2. diffsynth/configs/__init__.py +2 -0
  3. diffsynth/configs/model_configs.py +888 -0
  4. diffsynth/configs/vram_management_module_maps.py +284 -0
  5. diffsynth/core/__init__.py +6 -0
  6. diffsynth/core/attention/__init__.py +1 -0
  7. diffsynth/core/attention/attention.py +121 -0
  8. diffsynth/core/data/__init__.py +1 -0
  9. diffsynth/core/data/operators.py +280 -0
  10. diffsynth/core/data/unified_dataset.py +118 -0
  11. diffsynth/core/device/__init__.py +2 -0
  12. diffsynth/core/device/npu_compatible_device.py +107 -0
  13. diffsynth/core/gradient/__init__.py +1 -0
  14. diffsynth/core/gradient/gradient_checkpoint.py +37 -0
  15. diffsynth/core/loader/__init__.py +3 -0
  16. diffsynth/core/loader/config.py +119 -0
  17. diffsynth/core/loader/file.py +130 -0
  18. diffsynth/core/loader/model.py +115 -0
  19. diffsynth/core/npu_patch/npu_fused_operator.py +30 -0
  20. diffsynth/core/vram/__init__.py +2 -0
  21. diffsynth/core/vram/disk_map.py +93 -0
  22. diffsynth/core/vram/initialization.py +21 -0
  23. diffsynth/core/vram/layers.py +479 -0
  24. diffsynth/diffusion/__init__.py +6 -0
  25. diffsynth/diffusion/base_pipeline.py +500 -0
  26. diffsynth/diffusion/flow_match.py +236 -0
  27. diffsynth/diffusion/logger.py +43 -0
  28. diffsynth/diffusion/loss.py +158 -0
  29. diffsynth/diffusion/parsers.py +71 -0
  30. diffsynth/diffusion/runner.py +135 -0
  31. diffsynth/diffusion/training_module.py +302 -0
  32. diffsynth/models/anima_dit.py +1307 -0
  33. diffsynth/models/dinov3_image_encoder.py +96 -0
  34. diffsynth/models/flux2_dit.py +1053 -0
  35. diffsynth/models/flux2_text_encoder.py +58 -0
  36. diffsynth/models/flux2_vae.py +0 -0
  37. diffsynth/models/flux_controlnet.py +384 -0
  38. diffsynth/models/flux_dit.py +398 -0
  39. diffsynth/models/flux_infiniteyou.py +129 -0
  40. diffsynth/models/flux_ipadapter.py +110 -0
  41. diffsynth/models/flux_lora_encoder.py +521 -0
  42. diffsynth/models/flux_lora_patcher.py +306 -0
  43. diffsynth/models/flux_text_encoder_clip.py +112 -0
  44. diffsynth/models/flux_text_encoder_t5.py +43 -0
  45. diffsynth/models/flux_vae.py +451 -0
  46. diffsynth/models/flux_value_control.py +56 -0
  47. diffsynth/models/general_modules.py +146 -0
  48. diffsynth/models/longcat_video_dit.py +902 -0
  49. diffsynth/models/ltx2_audio_vae.py +1872 -0
  50. diffsynth/models/ltx2_common.py +388 -0
diffsynth/__init__.py ADDED
@@ -0,0 +1 @@
+ from .core import *
diffsynth/configs/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model_configs import MODEL_CONFIGS
+ from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
diffsynth/configs/model_configs.py ADDED
@@ -0,0 +1,888 @@
+ qwen_image_series = [
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
+         "model_hash": "0319a1cb19835fb510907dd3367c95ff",
+         "model_name": "qwen_image_dit",
+         "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
+     },
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
+         "model_hash": "8004730443f55db63092006dd9f7110e",
+         "model_name": "qwen_image_text_encoder",
+         "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
+         "model_hash": "ed4ea5824d55ec3107b09815e318123a",
+         "model_name": "qwen_image_vae",
+         "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
+         "model_hash": "073bce9cf969e317e5662cd570c3e79c",
+         "model_name": "qwen_image_blockwise_controlnet",
+         "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
+         "model_hash": "a9e54e480a628f0b956a688a81c33bab",
+         "model_name": "qwen_image_blockwise_controlnet",
+         "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
+         "extra_kwargs": {"additional_in_dim": 4},
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
+         "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
+         "model_name": "siglip2_image_encoder",
+         "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
+         "model_hash": "5722b5c873720009de96422993b15682",
+         "model_name": "dinov3_image_encoder",
+         "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
+     },
+     {
+         # Example:
+         "model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
+         "model_name": "qwen_image_image2lora_coarse",
+         "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
+     },
+     {
+         # Example:
+         "model_hash": "a5476e691767a4da6d3a6634a10f7408",
+         "model_name": "qwen_image_image2lora_fine",
+         "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
+         "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
+     },
+     {
+         # Example:
+         "model_hash": "0aad514690602ecaff932c701cb4b0bb",
+         "model_name": "qwen_image_image2lora_style",
+         "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
+         "extra_kwargs": {"compress_dim": 64, "use_residual": False}
+     },
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
+         "model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
+         "model_name": "qwen_image_dit",
+         "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
+         "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
+     },
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
+         "model_hash": "44b39ddc499e027cfb24f7878d7416b9",
+         "model_name": "qwen_image_vae",
+         "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
+         "extra_kwargs": {"image_channels": 4}
+     },
+ ]
+
+ wan_series = [
+     {
+         # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
+         "model_hash": "5ec04e02b42d2580483ad69f4e76346a",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
+         "model_hash": "9c8818c2cbea55eca56c7b447df170da",
+         "model_name": "wan_video_text_encoder",
+         "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
+         "model_hash": "ccc42284ea13e1ad04693284c7a09be6",
+         "model_name": "wan_video_vae",
+         "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
+         "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
+     },
+     {
+         # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
+         "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
+     },
+     {
+         # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
+         "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
+         "model_name": "wan_video_vap",
+         "model_class": "diffsynth.models.wan_video_mot.MotWanModel",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
+         "model_hash": "5941c53e207d62f20f9025686193c40b",
+         "model_name": "wan_video_image_encoder",
+         "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
+         "model_hash": "dbd5ec76bbf977983f972c151d545389",
+         "model_name": "wan_video_motion_controller",
+         "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "9269f8db9040a9d860eaca435be61814",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "349723183fc063b2bfc10bb2835cf677",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "6d6ccde6845b95ad9114ab993d917893",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "6bfcfb3b342cb286ce886889d519a77e",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "70ddad9d3a133785da5ea371aae09504",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "b61c605c2adbd23124d152ed28e049ae",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
+     },
+     {
+         # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
+         "model_name": "wan_video_vace",
+         "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
+         "extra_kwargs": {"use_target_text_encoder": True},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
+         "model_name": "wan_video_vace",
+         "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
+         "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'glyph_channels': 16, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
+         "model_name": "wan_video_animate_adapter",
+         "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
+         "model_hash": "47dbeab5e560db3180adf51dc0232fb1",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
+         "model_hash": "2267d489f0ceb9f21836532952852ee5",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
+         "model_hash": "5b013604280dd715f8457c6ed6d6a626",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "966cffdcc52f9c46c391768b27637614",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
+         "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+         "model_hash": "1f5ab7703c6fc803fdded85ff040c316",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
+         "model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
+         "model_name": "wan_video_vae",
+         "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
+         "model_hash": "06be60f3a4526586d8431cd038a71486",
+         "model_name": "wans2v_audio_encoder",
+         "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
+         "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
+         "model_name": "wan_video_dit",
+         "model_class": "diffsynth.models.wan_video_dit.WanModel",
+         "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
+     },
+ ]
+
+ flux_series = [
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
+         "model_hash": "a29710fea6dddb0314663ee823598e50",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+     },
+     {
+         # Supported due to historical reasons.
+         "model_hash": "605c56eab23e9e2af863ad8f0813a25d",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
+         "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
+         "model_name": "flux_text_encoder_clip",
+         "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
+         "model_hash": "22540b49eaedbc2f2784b2091a234c7c",
+         "model_name": "flux_text_encoder_t5",
+         "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
+         "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
+         "model_name": "flux_vae_encoder",
+         "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
+         "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
+         "model_name": "flux_vae_decoder",
+         "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
+         "model_hash": "d02f41c13549fa5093d3521f62a5570a",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
+         "model_hash": "0629116fce1472503a66992f96f3eb1a",
+         "model_name": "flux_value_controller",
+         "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
+     },
+     {
+         # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
+         "model_hash": "52357cb26250681367488a8954c271e8",
+         "model_name": "flux_controlnet",
+         "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+         "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
+     },
+     {
+         # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
+         "model_hash": "78d18b9101345ff695f312e7e62538c0",
+         "model_name": "flux_controlnet",
+         "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+         "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
+     },
+     {
+         # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
+         "model_hash": "b001c89139b5f053c715fe772362dd2a",
+         "model_name": "flux_controlnet",
+         "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+         "extra_kwargs": {"num_single_blocks": 0},
+     },
+     {
+         # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
+         "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
+         "model_name": "infiniteyou_image_projector",
+         "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
+         "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
+         "model_name": "flux_controlnet",
+         "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+         "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
+         "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
+         "model_name": "flux_lora_encoder",
+         "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
+         "model_hash": "30143afb2dea73d1ac580e0787628f8c",
+         "model_name": "flux_lora_patcher",
+         "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
+         "model_hash": "2bd19e845116e4f875a0a048e27fc219",
+         "model_name": "nexus_gen_llm",
+         "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
+         "model_hash": "63c969fd37cce769a90aa781fbff5f81",
+         "model_name": "nexus_gen_editing_adapter",
+         "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
+         "model_hash": "63c969fd37cce769a90aa781fbff5f81",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
+         "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
+         "model_name": "nexus_gen_generation_adapter",
+         "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
+         "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
+         "model_hash": "4daaa66cc656a8fe369908693dad0a35",
+         "model_name": "flux_ipadapter",
+         "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
+         "model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
+         "model_name": "siglip_vision_model",
+         "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
+         "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
+         "model_name": "step1x_connector",
+         "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
+         "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+         "extra_kwargs": {"disable_guidance_embedder": True},
+     },
+     {
+         # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
+         "model_hash": "3394f306c4cbf04334b712bf5aaed95f",
+         "model_name": "flux_dit",
+         "model_class": "diffsynth.models.flux_dit.FluxDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+     },
+ ]
+
+ flux2_series = [
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
+         "model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
+         "model_name": "flux2_text_encoder",
+         "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
+         "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
+         "model_name": "flux2_dit",
+         "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
+         "model_hash": "c54288e3ee12ca215898840682337b95",
+         "model_name": "flux2_vae",
+         "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
+         "model_hash": "3bde7b817fec8143028b6825a63180df",
+         "model_name": "flux2_dit",
+         "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+         "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
+         "model_hash": "9195f3ea256fcd0ae6d929c203470754",
+         "model_name": "z_image_text_encoder",
+         "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+         "extra_kwargs": {"model_size": "8B"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
+         "model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
+         "model_name": "flux2_dit",
+         "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+         "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
+     },
+ ]
+
+ z_image_series = [
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
+         "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
+         "model_name": "z_image_dit",
+         "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
+     },
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
+         "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
+         "model_name": "z_image_text_encoder",
+         "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+     },
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
+         "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
+         "model_name": "flux_vae_encoder",
+         "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
+         "extra_kwargs": {"use_conv_attention": False},
+     },
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
+         "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
+         "model_name": "flux_vae_decoder",
+         "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
+         "extra_kwargs": {"use_conv_attention": False},
+     },
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
+         "model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
+         "model_name": "z_image_dit",
+         "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
+         "extra_kwargs": {"siglip_feat_dim": 1152},
+     },
+     {
+         # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
+         "model_hash": "89d48e420f45cff95115a9f3e698d44a",
+         "model_name": "siglip_vision_model_428m",
+         "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
+     },
+     {
+         # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
+         "model_hash": "1677708d40029ab380a95f6c731a57d7",
+         "model_name": "z_image_controlnet",
+         "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
+     },
+     {
+         # Example: ???
+         "model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
+         "model_name": "z_image_image2lora_style",
+         "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
+         "extra_kwargs": {"compress_dim": 128},
+     },
+     {
+         # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
+         "model_hash": "1392adecee344136041e70553f875f31",
+         "model_name": "z_image_text_encoder",
+         "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
+         "extra_kwargs": {"model_size": "0.6B"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
+     },
+     {
+         # To ensure compatibility with the `model.diffusion_model` prefix introduced by other frameworks.
+         "model_hash": "8cf241a0d32f93d5de368502a086852f",
+         "model_name": "z_image_dit",
+         "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter",
+     },
+ ]
+ """
+ Official model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
+ Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
+ For the LTX-2 base models, both the official checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
+ and the repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are supported.
+ We have repackaged the official checkpoints in the DiffSynth-Studio/LTX-2-Repackage repo to support loading each submodule separately,
+ avoiding redundant memory usage when users only need part of the model.
+ """
+ ltx2_series = [
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_dit",
+         "model_class": "diffsynth.models.ltx2_dit.LTXModel",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
+         "model_hash": "c567aaa37d5ed7454c73aa6024458661",
+         "model_name": "ltx2_dit",
+         "model_class": "diffsynth.models.ltx2_dit.LTXModel",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_video_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
+         "model_hash": "7f7e904a53260ec0351b05f32153754b",
+         "model_name": "ltx2_video_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_video_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
+         "model_hash": "dc6029ca2825147872b45e35a2dc3a97",
+         "model_name": "ltx2_video_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_audio_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
+         "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
+         "model_name": "ltx2_audio_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_audio_vocoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
+         "model_hash": "f471360f6b24bef702ab73133d9f8bb9",
+         "model_name": "ltx2_audio_vocoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_audio_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
+         "model_hash": "29338f3b95e7e312a3460a482e4f4554",
+         "model_name": "ltx2_audio_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "aca7b0bbf8415e9c98360750268915fc",
+         "model_name": "ltx2_text_encoder_post_modules",
+         "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
+         "model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
+         "model_name": "ltx2_text_encoder_post_modules",
+         "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
+         "model_hash": "33917f31c4a79196171154cca39f165e",
+         "model_name": "ltx2_text_encoder",
+         "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
+         "model_hash": "c79c458c6e99e0e14d47e676761732d2",
+         "model_name": "ltx2_latent_upsampler",
+         "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_dit",
+         "model_class": "diffsynth.models.ltx2_dit.LTXModel",
+         "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_video_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
+         "extra_kwargs": {"encoder_version": "ltx-2.3"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_video_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
+         "extra_kwargs": {"decoder_version": "ltx-2.3"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_audio_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_audio_vocoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_audio_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors")
+         "model_hash": "f3a83ecf3995dcc4fae2d27e08ad5767",
+         "model_name": "ltx2_text_encoder_post_modules",
+         "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
+         "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
+         "model_hash": "aed408774d694a2452f69936c32febb5",
+         "model_name": "ltx2_latent_upsampler",
+         "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
+         "extra_kwargs": {"rational_resampler": False},
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors")
+         "model_hash": "1c55afad76ed33c112a2978550b524d1",
+         "model_name": "ltx2_dit",
+         "model_class": "diffsynth.models.ltx2_dit.LTXModel",
+         "extra_kwargs": {"apply_gated_attention": True, "cross_attention_adaln": True, "caption_channels": None},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
+         "model_hash": "eecdc07c2ec30863b8a2b8b2134036cf",
+         "model_name": "ltx2_video_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
+         "extra_kwargs": {"encoder_version": "ltx-2.3"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
+         "model_hash": "deda2f542e17ee25bc8c38fd605316ea",
+         "model_name": "ltx2_video_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
+         "extra_kwargs": {"decoder_version": "ltx-2.3"},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
+         "model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
+         "model_name": "ltx2_audio_vae_decoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
+         "model_hash": "29338f3b95e7e312a3460a482e4f4554",
+         "model_name": "ltx2_audio_vae_encoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors")
+         "model_hash": "cd436c99e69ec5c80f050f0944f02a15",
+         "model_name": "ltx2_audio_vocoder",
+         "model_class": "diffsynth.models.ltx2_audio_vae.LTX2VocoderWithBWE",
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
+     },
+     {
+         # Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
+         "model_hash": "05da2aab1c4b061f72c426311c165a43",
+         "model_name": "ltx2_text_encoder_post_modules",
+         "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
+         "extra_kwargs": {"separated_audio_video": True, "embedding_dim_gemma": 3840, "num_layers_gemma": 49, "video_attention_heads": 32, "video_attention_head_dim": 128, "audio_attention_heads": 32, "audio_attention_head_dim": 64, "num_connector_layers": 8, "apply_gated_attention": True},
+         "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
+     },
+ ]
850
+ anima_series = [
851
+ {
852
+ # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
853
+ "model_hash": "a9995952c2d8e63cf82e115005eb61b9",
854
+ "model_name": "z_image_text_encoder",
855
+ "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
856
+ "extra_kwargs": {"model_size": "0.6B"},
857
+ },
858
+ {
859
+ # Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
860
+ "model_hash": "417673936471e79e31ed4d186d7a3f4a",
861
+ "model_name": "anima_dit",
862
+ "model_class": "diffsynth.models.anima_dit.AnimaDiT",
863
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
864
+ }
865
+ ]
866
+
867
+ mova_series = [
868
+ # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors")
869
+ {
870
+ "model_hash": "8c57e12790e2c45a64817e0ce28cde2f",
871
+ "model_name": "mova_audio_dit",
872
+ "model_class": "diffsynth.models.mova_audio_dit.MovaAudioDit",
873
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
874
+ },
875
+ # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors")
876
+ {
877
+ "model_hash": "418517fb2b4e919d2cac8f314fcf82ac",
878
+ "model_name": "mova_audio_vae",
879
+ "model_class": "diffsynth.models.mova_audio_vae.DacVAE",
880
+ },
881
+ # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors")
882
+ {
883
+ "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb",
884
+ "model_name": "mova_dual_tower_bridge",
885
+ "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
886
+ },
887
+ ]
888
+ MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
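A minimal lookup sketch (not part of the commit): each config is keyed by a hash of the checkpoint's state dict keys, so a loader can identify the right model class for an arbitrary file. `hash_state_dict_keys` and `load_state_dict` are exported from diffsynth.core.loader; the file path below is hypothetical.

from diffsynth.core.loader import load_state_dict, hash_state_dict_keys

state_dict = load_state_dict("models/transformer.safetensors")  # hypothetical path
model_hash = hash_state_dict_keys(state_dict)
matches = [c for c in MODEL_CONFIGS if c["model_hash"] == model_hash]
if matches:
    print(matches[0]["model_name"], matches[0]["model_class"])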
diffsynth/configs/vram_management_module_maps.py ADDED
@@ -0,0 +1,284 @@
1
+ flux_general_vram_config = {
2
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
3
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
4
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
5
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
6
+ "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
7
+ "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
8
+ "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
9
+ "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
10
+ }
11
+
12
+ VRAM_MANAGEMENT_MODULE_MAPS = {
13
+ "diffsynth.models.qwen_image_dit.QwenImageDiT": {
14
+ "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
15
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
16
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
17
+ },
18
+ "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
19
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
20
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
21
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
22
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
23
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
24
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
25
+ },
26
+ "diffsynth.models.qwen_image_vae.QwenImageVAE": {
27
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
28
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
29
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
30
+ "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
31
+ },
32
+ "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
33
+ "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
34
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
35
+ },
36
+ "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
37
+ "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
38
+ "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
39
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
40
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
41
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
42
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
43
+ },
44
+ "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
45
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
46
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
47
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
48
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
49
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
50
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
51
+ },
52
+ "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
53
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
54
+ },
55
+ "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
56
+ "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
57
+ "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
58
+ "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
59
+ "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
60
+ "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
61
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
62
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
63
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
64
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
65
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
66
+ },
67
+ "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
68
+ "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
69
+ "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
70
+ "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
71
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
72
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
73
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
74
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
75
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
76
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
77
+ },
78
+ "diffsynth.models.wan_video_dit.WanModel": {
79
+ "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
80
+ "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
81
+ "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
82
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
83
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
84
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
85
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
86
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
87
+ },
88
+ "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
89
+ "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
90
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
91
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
92
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
93
+ },
94
+ "diffsynth.models.wan_video_mot.MotWanModel": {
95
+ "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
96
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
97
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
98
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
99
+ },
100
+ "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
101
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
102
+ },
103
+ "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
104
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
105
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
106
+ "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
107
+ "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
108
+ },
109
+ "diffsynth.models.wan_video_vace.VaceWanModel": {
110
+ "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
111
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
112
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
113
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
114
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
115
+ },
116
+ "diffsynth.models.wan_video_vae.WanVideoVAE": {
117
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
118
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
119
+ "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
120
+ "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
121
+ "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
122
+ "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
123
+ "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
124
+ },
125
+ "diffsynth.models.wan_video_vae.WanVideoVAE38": {
126
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
127
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
128
+ "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
129
+ "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
130
+ "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
131
+ "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
132
+ "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
133
+ },
134
+ "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
135
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
136
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
137
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
138
+ },
139
+ "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
140
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
141
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
142
+ "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
143
+ "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
144
+ },
145
+ "diffsynth.models.flux_dit.FluxDiT": {
146
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
147
+ "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
148
+ },
149
+ "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
150
+ "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
151
+ "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
152
+ "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
153
+ "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
154
+ "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
155
+ "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
156
+ "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
157
+ "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
158
+ "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
159
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
160
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
161
+ "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
162
+ "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
163
+ "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
164
+ },
165
+ "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
166
+ "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
167
+ "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
168
+ "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
169
+ "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
170
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
171
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
172
+ },
173
+ "diffsynth.models.flux2_dit.Flux2DiT": {
174
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
175
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
176
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
177
+ },
178
+ "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
179
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
180
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
181
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
182
+ "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
183
+ },
184
+ "diffsynth.models.flux2_vae.Flux2VAE": {
185
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
186
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
187
+ "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
188
+ },
189
+ "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
190
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
191
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
192
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
193
+ },
194
+ "diffsynth.models.z_image_dit.ZImageDiT": {
195
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
196
+ "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
197
+ },
198
+ "diffsynth.models.z_image_controlnet.ZImageControlNet": {
199
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
200
+ "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
201
+ },
202
+ "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
203
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
204
+ },
205
+ "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
206
+ "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
207
+ "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
208
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
209
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
210
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
211
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
212
+ },
213
+ "diffsynth.models.ltx2_dit.LTXModel": {
214
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
215
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
216
+ },
217
+ "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
218
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
219
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
220
+ "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
221
+ },
222
+ "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
223
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
224
+ },
225
+ "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
226
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
227
+ },
228
+ "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
229
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
230
+ },
231
+ "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
232
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
233
+ "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
234
+ },
235
+ "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
236
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
237
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
238
+ "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
239
+ },
240
+ "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
241
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
242
+ "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
243
+ "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
244
+ "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
245
+ },
246
+ "diffsynth.models.anima_dit.AnimaDiT": {
247
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
248
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
249
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
250
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
251
+ },
252
+ "diffsynth.models.mova_audio_dit.MovaAudioDit": {
253
+ "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
254
+ "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
255
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
256
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
257
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
258
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
259
+ },
260
+ "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": {
261
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
262
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
263
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
264
+ },
265
+ "diffsynth.models.mova_audio_vae.DacVAE": {
266
+ "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
267
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
268
+ "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
269
+ },
270
+ }
271
+
272
+ def QwenImageTextEncoder_Module_Map_Updater():
273
+ current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
274
+ from packaging import version
275
+ import transformers
276
+ if version.parse(transformers.__version__) >= version.parse("5.2.0"):
277
+ # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
278
+ current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
279
+ current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
280
+ return current
281
+
282
+ VERSION_CHECKER_MAPS = {
283
+ "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
284
+ }
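A hedged sketch of how these maps are meant to be consumed: resolve the module map for a model class, routing through VERSION_CHECKER_MAPS when a version-dependent updater is registered. `resolve_module_map` is a hypothetical helper, not part of this file.

def resolve_module_map(model_class_name: str) -> dict:
    # Version checkers return an (updated) map; otherwise use the static one.
    if model_class_name in VERSION_CHECKER_MAPS:
        return VERSION_CHECKER_MAPS[model_class_name]()
    return VRAM_MANAGEMENT_MODULE_MAPS[model_class_name]

module_map = resolve_module_map("diffsynth.models.flux_dit.FluxDiT")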
diffsynth/core/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .attention import *
2
+ from .data import *
3
+ from .gradient import *
4
+ from .loader import *
5
+ from .vram import *
6
+ from .device import *
diffsynth/core/attention/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .attention import attention_forward
diffsynth/core/attention/attention.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch, os
2
+ from einops import rearrange
3
+
4
+
5
+ try:
6
+ import flash_attn_interface
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+ FLASH_ATTN_2_AVAILABLE = True
14
+ except ModuleNotFoundError:
15
+ FLASH_ATTN_2_AVAILABLE = False
16
+
17
+ try:
18
+ from sageattention import sageattn
19
+ SAGE_ATTN_AVAILABLE = True
20
+ except ModuleNotFoundError:
21
+ SAGE_ATTN_AVAILABLE = False
22
+
23
+ try:
24
+ import xformers.ops as xops
25
+ XFORMERS_AVAILABLE = True
26
+ except ModuleNotFoundError:
27
+ XFORMERS_AVAILABLE = False
28
+
29
+
30
+ def initialize_attention_priority():
31
+ if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
32
+ return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
33
+ elif FLASH_ATTN_3_AVAILABLE:
34
+ return "flash_attention_3"
35
+ elif FLASH_ATTN_2_AVAILABLE:
36
+ return "flash_attention_2"
37
+ elif SAGE_ATTN_AVAILABLE:
38
+ return "sage_attention"
39
+ elif XFORMERS_AVAILABLE:
40
+ return "xformers"
41
+ else:
42
+ return "torch"
43
+
44
+
45
+ ATTENTION_IMPLEMENTATION = initialize_attention_priority()
46
+
47
+
48
+ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
49
+ dims = {} if dims is None else dims
50
+ if q_pattern != required_in_pattern:
51
+ q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
52
+ if k_pattern != required_in_pattern:
53
+ k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
54
+ if v_pattern != required_in_pattern:
55
+ v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
56
+ return q, k, v
57
+
58
+
59
+ def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
60
+ dims = {} if dims is None else dims
61
+ if out_pattern != required_out_pattern:
62
+ out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
63
+ return out
64
+
65
+
66
+ def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
67
+ required_in_pattern, required_out_pattern= "b n s d", "b n s d"
68
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
69
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
70
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
71
+ return out
72
+
73
+
74
+ def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
75
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
76
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
77
+ out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
78
+ if isinstance(out, tuple):
79
+ out = out[0]
80
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
81
+ return out
82
+
83
+
84
+ def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
85
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
86
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
87
+ out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
88
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
89
+ return out
90
+
91
+
92
+ def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
93
+ required_in_pattern, required_out_pattern= "b n s d", "b n s d"
94
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
95
+ out = sageattn(q, k, v, sm_scale=scale)
96
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
97
+ return out
98
+
99
+
100
+ def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
101
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
102
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
103
+ out = xops.memory_efficient_attention(q, k, v, scale=scale)
104
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
105
+ return out
106
+
107
+
108
+ def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
109
+ if compatibility_mode or (attn_mask is not None):
110
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
111
+ else:
112
+ if ATTENTION_IMPLEMENTATION == "flash_attention_3":
113
+ return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
114
+ elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
115
+ return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
116
+ elif ATTENTION_IMPLEMENTATION == "sage_attention":
117
+ return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
118
+ elif ATTENTION_IMPLEMENTATION == "xformers":
119
+ return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
120
+ else:
121
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
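Usage sketch for attention_forward (tensor shapes are illustrative). The backend is selected once at import time; setting DIFFSYNTH_ATTENTION_IMPLEMENTATION (e.g. "torch" or "flash_attention_2") before import pins it explicitly, and any call with an attn_mask always falls back to torch SDPA.

import torch

q = torch.randn(2, 8, 128, 64)  # "b n s d": batch, heads, seq, head_dim
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = attention_forward(q, k, v)                          # "b n s d" in and out
out2 = attention_forward(q, k, v, out_pattern="b s n d")  # output rearranged to (2, 128, 8, 64)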
diffsynth/core/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .unified_dataset import UnifiedDataset
diffsynth/core/data/operators.py ADDED
@@ -0,0 +1,280 @@
1
+ import math
2
+ import torch, torchvision, imageio, os
3
+ import imageio.v3 as iio
4
+ from PIL import Image
5
+ import torchaudio
6
+
7
+
8
+ class DataProcessingPipeline:
9
+ def __init__(self, operators=None):
10
+ self.operators: list[DataProcessingOperator] = [] if operators is None else operators
11
+
12
+ def __call__(self, data):
13
+ for operator in self.operators:
14
+ data = operator(data)
15
+ return data
16
+
17
+ def __rshift__(self, pipe):
18
+ if isinstance(pipe, DataProcessingOperator):
19
+ pipe = DataProcessingPipeline([pipe])
20
+ return DataProcessingPipeline(self.operators + pipe.operators)
21
+
22
+
23
+ class DataProcessingOperator:
24
+ def __call__(self, data):
25
+ raise NotImplementedError("DataProcessingOperator cannot be called directly.")
26
+
27
+ def __rshift__(self, pipe):
28
+ if isinstance(pipe, DataProcessingOperator):
29
+ pipe = DataProcessingPipeline([pipe])
30
+ return DataProcessingPipeline([self]).__rshift__(pipe)
31
+
32
+
33
+ class DataProcessingOperatorRaw(DataProcessingOperator):
34
+ def __call__(self, data):
35
+ return data
36
+
37
+
38
+ class ToInt(DataProcessingOperator):
39
+ def __call__(self, data):
40
+ return int(data)
41
+
42
+
43
+ class ToFloat(DataProcessingOperator):
44
+ def __call__(self, data):
45
+ return float(data)
46
+
47
+
48
+ class ToStr(DataProcessingOperator):
49
+ def __init__(self, none_value=""):
50
+ self.none_value = none_value
51
+
52
+ def __call__(self, data):
53
+ if data is None: data = self.none_value
54
+ return str(data)
55
+
56
+
57
+ class LoadImage(DataProcessingOperator):
58
+ def __init__(self, convert_RGB=True, convert_RGBA=False):
59
+ self.convert_RGB = convert_RGB
60
+ self.convert_RGBA = convert_RGBA
61
+
62
+ def __call__(self, data: str):
63
+ image = Image.open(data)
64
+ if self.convert_RGB: image = image.convert("RGB")
65
+ if self.convert_RGBA: image = image.convert("RGBA")
66
+ return image
67
+
68
+
69
+ class ImageCropAndResize(DataProcessingOperator):
70
+ def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
71
+ self.height = height
72
+ self.width = width
73
+ self.max_pixels = max_pixels
74
+ self.height_division_factor = height_division_factor
75
+ self.width_division_factor = width_division_factor
76
+
77
+ def crop_and_resize(self, image, target_height, target_width):
78
+ width, height = image.size
79
+ scale = max(target_width / width, target_height / height)
80
+ image = torchvision.transforms.functional.resize(
81
+ image,
82
+ (round(height*scale), round(width*scale)),
83
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR
84
+ )
85
+ image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
86
+ return image
87
+
88
+ def get_height_width(self, image):
89
+ if self.height is None or self.width is None:
90
+ width, height = image.size
91
+ if width * height > self.max_pixels:
92
+ scale = (width * height / self.max_pixels) ** 0.5
93
+ height, width = int(height / scale), int(width / scale)
94
+ height = height // self.height_division_factor * self.height_division_factor
95
+ width = width // self.width_division_factor * self.width_division_factor
96
+ else:
97
+ height, width = self.height, self.width
98
+ return height, width
99
+
100
+ def __call__(self, data: Image.Image):
101
+ image = self.crop_and_resize(data, *self.get_height_width(data))
102
+ return image
103
+
104
+
105
+ class ToList(DataProcessingOperator):
106
+ def __call__(self, data):
107
+ return [data]
108
+
109
+
110
+ class FrameSamplerByRateMixin:
111
+ def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False):
112
+ self.num_frames = num_frames
113
+ self.time_division_factor = time_division_factor
114
+ self.time_division_remainder = time_division_remainder
115
+ self.frame_rate = frame_rate
116
+ self.fix_frame_rate = fix_frame_rate
117
+
118
+ def get_reader(self, data: str):
119
+ return imageio.get_reader(data)
120
+
121
+ def get_available_num_frames(self, reader):
122
+ if not self.fix_frame_rate:
123
+ return reader.count_frames()
124
+ meta_data = reader.get_meta_data()
125
+ total_original_frames = int(reader.count_frames())
126
+ duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps']
127
+ total_available_frames = math.floor(duration * self.frame_rate)
128
+ return int(total_available_frames)
129
+
130
+ def get_num_frames(self, reader):
131
+ num_frames = self.num_frames
132
+ total_frames = self.get_available_num_frames(reader)
133
+ if int(total_frames) < num_frames:
134
+ num_frames = total_frames
135
+ while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
136
+ num_frames -= 1
137
+ return num_frames
138
+
139
+ def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int:
140
+ if not self.fix_frame_rate:
141
+ return new_sequence_id
142
+ target_time_in_seconds = new_sequence_id / self.frame_rate
143
+ raw_frame_index_float = target_time_in_seconds * raw_frame_rate
144
+ frame_id = int(round(raw_frame_index_float))
145
+ frame_id = min(frame_id, total_raw_frames - 1)
146
+ return frame_id
147
+
148
+
149
+ class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin):
150
+ def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False):
151
+ FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
152
+ # frame_processor is build in the video loader for high efficiency.
153
+ self.frame_processor = frame_processor
154
+
155
+ def __call__(self, data: str):
156
+ reader = self.get_reader(data)
157
+ raw_frame_rate = reader.get_meta_data()['fps']
158
+ total_raw_frames = reader.count_frames()
159
+ total_available = self.get_available_num_frames(reader)
160
+ # Pad short videos with the last frame instead of reducing num_frames
161
+ num_frames = self.num_frames
162
+ frames = []
163
+ for frame_id in range(num_frames):
164
+ if frame_id < total_available:
165
+ raw_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames)
166
+ frame = reader.get_data(raw_id)
167
+ frame = Image.fromarray(frame)
168
+ frame = self.frame_processor(frame)
169
+ frames.append(frame)
170
+ else:
171
+ # Pad with the last frame
172
+ frames.append(frames[-1])
173
+ reader.close()
174
+ return frames
175
+
176
+
177
+ class SequencialProcess(DataProcessingOperator):
178
+ def __init__(self, operator=lambda x: x):
179
+ self.operator = operator
180
+
181
+ def __call__(self, data):
182
+ return [self.operator(i) for i in data]
183
+
184
+
185
+ class LoadGIF(DataProcessingOperator):
186
+ def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
187
+ self.num_frames = num_frames
188
+ self.time_division_factor = time_division_factor
189
+ self.time_division_remainder = time_division_remainder
190
+ # frame_processor is built into the video loader for efficiency.
191
+ self.frame_processor = frame_processor
192
+
193
+ def get_num_frames(self, path):
194
+ num_frames = self.num_frames
195
+ images = iio.imread(path, mode="RGB")
196
+ if len(images) < num_frames:
197
+ num_frames = len(images)
198
+ while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
199
+ num_frames -= 1
200
+ return num_frames
201
+
202
+ def __call__(self, data: str):
203
+ num_frames = self.get_num_frames(data)
204
+ frames = []
205
+ images = iio.imread(data, mode="RGB")
206
+ for img in images:
207
+ frame = Image.fromarray(img)
208
+ frame = self.frame_processor(frame)
209
+ frames.append(frame)
210
+ if len(frames) >= num_frames:
211
+ break
212
+ return frames
213
+
214
+
215
+ class RouteByExtensionName(DataProcessingOperator):
216
+ def __init__(self, operator_map):
217
+ self.operator_map = operator_map
218
+
219
+ def __call__(self, data: str):
220
+ file_ext_name = data.split(".")[-1].lower()
221
+ for ext_names, operator in self.operator_map:
222
+ if ext_names is None or file_ext_name in ext_names:
223
+ return operator(data)
224
+ raise ValueError(f"Unsupported file: {data}")
225
+
226
+
227
+ class RouteByType(DataProcessingOperator):
228
+ def __init__(self, operator_map):
229
+ self.operator_map = operator_map
230
+
231
+ def __call__(self, data):
232
+ for dtype, operator in self.operator_map:
233
+ if dtype is None or isinstance(data, dtype):
234
+ return operator(data)
235
+ raise ValueError(f"Unsupported data: {data}")
236
+
237
+
238
+ class LoadTorchPickle(DataProcessingOperator):
239
+ def __init__(self, map_location="cpu"):
240
+ self.map_location = map_location
241
+
242
+ def __call__(self, data):
243
+ return torch.load(data, map_location=self.map_location, weights_only=False)
244
+
245
+
246
+ class ToAbsolutePath(DataProcessingOperator):
247
+ def __init__(self, base_path=""):
248
+ self.base_path = base_path
249
+
250
+ def __call__(self, data):
251
+ return os.path.join(self.base_path, data)
252
+
253
+
254
+ class LoadAudio(DataProcessingOperator):
255
+ def __init__(self, sr=16000):
256
+ self.sr = sr
257
+ def __call__(self, data: str):
258
+ import librosa
259
+ input_audio, sample_rate = librosa.load(data, sr=self.sr)
260
+ return input_audio
261
+
262
+
263
+ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
264
+
265
+ def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True):
266
+ FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
267
+
268
+ def __call__(self, data: str):
269
+ reader = self.get_reader(data)
270
+ num_frames = self.get_num_frames(reader)
271
+ duration = num_frames / self.frame_rate
272
+ waveform, sample_rate = torchaudio.load(data)
273
+ target_samples = int(duration * sample_rate)
274
+ current_samples = waveform.shape[-1]
275
+ if current_samples > target_samples:
276
+ waveform = waveform[..., :target_samples]
277
+ elif current_samples < target_samples:
278
+ padding = target_samples - current_samples
279
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
280
+ return waveform, sample_rate
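A hedged sketch of operator composition: `>>` chains operators into a DataProcessingPipeline, so calling the pipeline feeds each operator's output into the next. The directory and file name below are hypothetical.

pipeline = (
    ToAbsolutePath("/data/images")
    >> LoadImage()
    >> ImageCropAndResize(height=512, width=512)
)
image = pipeline("cat.jpg")  # loads /data/images/cat.jpg, then crops and resizes to 512x512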
diffsynth/core/data/unified_dataset.py ADDED
@@ -0,0 +1,118 @@
1
+ from .operators import *
2
+ import torch, json, pandas
3
+
4
+
5
+ class UnifiedDataset(torch.utils.data.Dataset):
6
+ def __init__(
7
+ self,
8
+ base_path=None, metadata_path=None,
9
+ repeat=1,
10
+ data_file_keys=tuple(),
11
+ main_data_operator=lambda x: x,
12
+ special_operator_map=None,
13
+ max_data_items=None,
14
+ ):
15
+ self.base_path = base_path
16
+ self.metadata_path = metadata_path
17
+ self.repeat = repeat
18
+ self.data_file_keys = data_file_keys
19
+ self.main_data_operator = main_data_operator
20
+ self.cached_data_operator = LoadTorchPickle()
21
+ self.special_operator_map = {} if special_operator_map is None else special_operator_map
22
+ self.max_data_items = max_data_items
23
+ self.data = []
24
+ self.cached_data = []
25
+ self.load_from_cache = metadata_path is None
26
+ self.load_metadata(metadata_path)
27
+
28
+ @staticmethod
29
+ def default_image_operator(
30
+ base_path="",
31
+ max_pixels=1920*1080, height=None, width=None,
32
+ height_division_factor=16, width_division_factor=16,
33
+ ):
34
+ return RouteByType(operator_map=[
35
+ (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
36
+ (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
37
+ ])
38
+
39
+ @staticmethod
40
+ def default_video_operator(
41
+ base_path="",
42
+ max_pixels=1920*1080, height=None, width=None,
43
+ height_division_factor=16, width_division_factor=16,
44
+ num_frames=81, time_division_factor=4, time_division_remainder=1,
45
+ frame_rate=24, fix_frame_rate=False,
46
+ ):
47
+ return RouteByType(operator_map=[
48
+ (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
49
+ (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
50
+ (("gif",), LoadGIF(
51
+ num_frames, time_division_factor, time_division_remainder,
52
+ frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
53
+ )),
54
+ (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
55
+ num_frames, time_division_factor, time_division_remainder,
56
+ frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
57
+ frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
58
+ )),
59
+ ])),
60
+ ])
61
+
62
+ def search_for_cached_data_files(self, path):
63
+ for file_name in os.listdir(path):
64
+ subpath = os.path.join(path, file_name)
65
+ if os.path.isdir(subpath):
66
+ self.search_for_cached_data_files(subpath)
67
+ elif subpath.endswith(".pth"):
68
+ self.cached_data.append(subpath)
69
+
70
+ def load_metadata(self, metadata_path):
71
+ if metadata_path is None:
72
+ print("No metadata_path. Searching for cached data files.")
73
+ self.search_for_cached_data_files(self.base_path)
74
+ print(f"{len(self.cached_data)} cached data files found.")
75
+ elif metadata_path.endswith(".json"):
76
+ with open(metadata_path, "r") as f:
77
+ metadata = json.load(f)
78
+ self.data = metadata
79
+ elif metadata_path.endswith(".jsonl"):
80
+ metadata = []
81
+ with open(metadata_path, 'r') as f:
82
+ for line in f:
83
+ metadata.append(json.loads(line.strip()))
84
+ self.data = metadata
85
+ else:
86
+ metadata = pandas.read_csv(metadata_path)
87
+ self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
88
+
89
+ def __getitem__(self, data_id):
90
+ if self.load_from_cache:
91
+ data = self.cached_data[data_id % len(self.cached_data)]
92
+ data = self.cached_data_operator(data)
93
+ else:
94
+ data = self.data[data_id % len(self.data)].copy()
95
+ for key in self.data_file_keys:
96
+ if key in data:
97
+ if key in self.special_operator_map:
98
+ data[key] = self.special_operator_map[key](data[key])
99
+ elif key in self.data_file_keys:
100
+ data[key] = self.main_data_operator(data[key])
101
+ return data
102
+
103
+ def __len__(self):
104
+ if self.max_data_items is not None:
105
+ return self.max_data_items
106
+ elif self.load_from_cache:
107
+ return len(self.cached_data) * self.repeat
108
+ else:
109
+ return len(self.data) * self.repeat
110
+
111
+ def check_data_equal(self, data1, data2):
112
+ # Debug only
113
+ if len(data1) != len(data2):
114
+ return False
115
+ for k in data1:
116
+ if data1[k] != data2[k]:
117
+ return False
118
+ return True
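A minimal construction sketch (file layout and column names are hypothetical): metadata rows are loaded from CSV/JSON/JSONL, and every key listed in data_file_keys is routed through main_data_operator unless a special operator is registered for it.

dataset = UnifiedDataset(
    base_path="/data",
    metadata_path="/data/metadata.csv",  # rows with e.g. "image" and "prompt" columns
    data_file_keys=("image",),
    main_data_operator=UnifiedDataset.default_image_operator(base_path="/data"),
)
sample = dataset[0]  # {"image": <PIL.Image>, "prompt": "..."}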
diffsynth/core/device/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
2
+ from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
diffsynth/core/device/npu_compatible_device.py ADDED
@@ -0,0 +1,107 @@
1
+ import importlib
2
+ import torch
3
+ from typing import Any
4
+
5
+
6
+ def is_torch_npu_available():
7
+ return importlib.util.find_spec("torch_npu") is not None
8
+
9
+
10
+ IS_CUDA_AVAILABLE = torch.cuda.is_available()
11
+ IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
12
+
13
+ if IS_NPU_AVAILABLE:
14
+ import torch_npu
15
+
16
+ torch.npu.config.allow_internal_format = False
17
+
18
+
19
+ def get_device_type() -> str:
20
+ """Get device type based on current machine, currently only support CPU, CUDA, NPU."""
21
+ if IS_CUDA_AVAILABLE:
22
+ device = "cuda"
23
+ elif IS_NPU_AVAILABLE:
24
+ device = "npu"
25
+ else:
26
+ device = "cpu"
27
+
28
+ return device
29
+
30
+
31
+ def get_torch_device() -> Any:
32
+ """Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
33
+ device_name = get_device_type()
34
+
35
+ try:
36
+ return getattr(torch, device_name)
37
+ except AttributeError:
38
+ print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
39
+ return torch.cuda
40
+
41
+
42
+ def get_device_id() -> int:
43
+ """Get current device id based on device type."""
44
+ return get_torch_device().current_device()
45
+
46
+
47
+ def get_device_name() -> str:
48
+ """Get current device name based on device type."""
49
+ return f"{get_device_type()}:{get_device_id()}"
50
+
51
+
52
+ def synchronize() -> None:
53
+ """Execute torch synchronize operation."""
54
+ get_torch_device().synchronize()
55
+
56
+
57
+ def empty_cache() -> None:
58
+ """Execute torch empty cache operation."""
59
+ get_torch_device().empty_cache()
60
+
61
+
62
+ def get_nccl_backend() -> str:
63
+ """Return distributed communication backend type based on device type."""
64
+ if IS_CUDA_AVAILABLE:
65
+ return "nccl"
66
+ elif IS_NPU_AVAILABLE:
67
+ return "hccl"
68
+ else:
69
+ raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
70
+
71
+
72
+ def enable_high_precision_for_bf16():
73
+ """
74
+ Set high accumulation dtype for matmul and reduction.
75
+ """
76
+ if IS_CUDA_AVAILABLE:
77
+ torch.backends.cuda.matmul.allow_tf32 = False
78
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
79
+
80
+ if IS_NPU_AVAILABLE:
81
+ torch.npu.matmul.allow_tf32 = False
82
+ torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
83
+
84
+
85
+ def parse_device_type(device):
86
+ if isinstance(device, str):
87
+ if device.startswith("cuda"):
88
+ return "cuda"
89
+ elif device.startswith("npu"):
90
+ return "npu"
91
+ else:
92
+ return "cpu"
93
+ elif isinstance(device, torch.device):
94
+ return device.type
95
+
96
+
97
+ def parse_nccl_backend(device_type):
98
+ if device_type == "cuda":
99
+ return "nccl"
100
+ elif device_type == "npu":
101
+ return "hccl"
102
+ else:
103
+ raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
104
+
105
+
106
+ def get_available_device_type():
107
+ return get_device_type()
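Usage sketch for the device helpers; the values shown assume a single-GPU CUDA machine.

device_type = get_available_device_type()  # "cuda", "npu", or "cpu"
backend = parse_nccl_backend(device_type)  # "nccl" on CUDA, "hccl" on NPU
print(get_device_name())                   # e.g. "cuda:0"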
diffsynth/core/gradient/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .gradient_checkpoint import gradient_checkpoint_forward
diffsynth/core/gradient/gradient_checkpoint.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import warnings
3
+ # Suppress checkpoint requires_grad warning - gradients flow through model params, not inputs
4
+ warnings.filterwarnings("ignore", message=".*None of the inputs have requires_grad.*")
5
+
6
+
7
+ def create_custom_forward(module):
8
+ def custom_forward(*inputs, **kwargs):
9
+ return module(*inputs, **kwargs)
10
+ return custom_forward
11
+
12
+
13
+ def gradient_checkpoint_forward(
14
+ model,
15
+ use_gradient_checkpointing,
16
+ use_gradient_checkpointing_offload,
17
+ *args,
18
+ **kwargs,
19
+ ):
20
+ if use_gradient_checkpointing_offload:
21
+ with torch.autograd.graph.save_on_cpu():
22
+ model_output = torch.utils.checkpoint.checkpoint(
23
+ create_custom_forward(model),
24
+ *args,
25
+ **kwargs,
26
+ use_reentrant=True,
27
+ )
28
+ elif use_gradient_checkpointing:
29
+ model_output = torch.utils.checkpoint.checkpoint(
30
+ create_custom_forward(model),
31
+ *args,
32
+ **kwargs,
33
+ use_reentrant=True,
34
+ )
35
+ else:
36
+ model_output = model(*args, **kwargs)
37
+ return model_output
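A hedged sketch of the intended call pattern: the two checkpointing flags are passed positionally, followed by the wrapped module's own arguments. The toy linear blocks below stand in for transformer blocks.

import torch

blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))
x = torch.randn(2, 64)
for block in blocks:
    # (model, use_gradient_checkpointing, use_gradient_checkpointing_offload, *args)
    x = gradient_checkpoint_forward(block, True, False, x)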
diffsynth/core/loader/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .file import load_state_dict, hash_state_dict_keys, hash_model_file
2
+ from .model import load_model, load_model_with_disk_offload
3
+ from .config import ModelConfig
diffsynth/core/loader/config.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch, glob, os
2
+ from typing import Optional, Union, Dict
3
+ from dataclasses import dataclass
4
+ from modelscope import snapshot_download
5
+ from huggingface_hub import snapshot_download as hf_snapshot_download
7
+
8
+
9
+ @dataclass
10
+ class ModelConfig:
11
+ path: Union[str, list[str]] = None
12
+ model_id: str = None
13
+ origin_file_pattern: Union[str, list[str]] = None
14
+ download_source: str = None
15
+ local_model_path: str = None
16
+ skip_download: bool = None
17
+ offload_device: Optional[Union[str, torch.device]] = None
18
+ offload_dtype: Optional[torch.dtype] = None
19
+ onload_device: Optional[Union[str, torch.device]] = None
20
+ onload_dtype: Optional[torch.dtype] = None
21
+ preparing_device: Optional[Union[str, torch.device]] = None
22
+ preparing_dtype: Optional[torch.dtype] = None
23
+ computation_device: Optional[Union[str, torch.device]] = None
24
+ computation_dtype: Optional[torch.dtype] = None
25
+ clear_parameters: bool = False
26
+ state_dict: Dict[str, torch.Tensor] = None
27
+
28
+ def check_input(self):
29
+ if self.path is None and self.model_id is None:
30
+ raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
31
+
32
+ def parse_original_file_pattern(self):
33
+ if self.origin_file_pattern in [None, "", "./"]:
34
+ return "*"
35
+ elif self.origin_file_pattern.endswith("/"):
36
+ return self.origin_file_pattern + "*"
37
+ else:
38
+ return self.origin_file_pattern
39
+
40
+ def parse_download_source(self):
41
+ if self.download_source is None:
42
+ if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
43
+ return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
44
+ else:
45
+ return "modelscope"
46
+ else:
47
+ return self.download_source
48
+
49
+ def parse_skip_download(self):
50
+ if self.skip_download is None:
51
+ if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
52
+ if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
53
+ return True
54
+ elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
55
+ return False
56
+ else:
57
+ return False
58
+ else:
59
+ return self.skip_download
60
+
61
+ def download(self):
62
+ origin_file_pattern = self.parse_original_file_pattern()
63
+ downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
64
+ download_source = self.parse_download_source()
65
+ if download_source.lower() == "modelscope":
66
+ snapshot_download(
67
+ self.model_id,
68
+ local_dir=os.path.join(self.local_model_path, self.model_id),
69
+ allow_file_pattern=origin_file_pattern,
70
+ ignore_file_pattern=downloaded_files,
71
+ local_files_only=False
72
+ )
73
+ elif download_source.lower() == "huggingface":
74
+ hf_snapshot_download(
75
+ self.model_id,
76
+ local_dir=os.path.join(self.local_model_path, self.model_id),
77
+ allow_patterns=origin_file_pattern,
78
+ ignore_patterns=downloaded_files,
79
+ local_files_only=False
80
+ )
81
+ else:
82
+ raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
83
+
84
+ def require_downloading(self):
85
+ if self.path is not None:
86
+ return False
87
+ skip_download = self.parse_skip_download()
88
+ return not skip_download
89
+
90
+ def reset_local_model_path(self):
91
+ if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
92
+ self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
93
+ elif self.local_model_path is None:
94
+ self.local_model_path = "./models"
95
+
96
+ def download_if_necessary(self):
97
+ self.check_input()
98
+ self.reset_local_model_path()
99
+ if self.require_downloading():
100
+ self.download()
101
+ if self.path is None:
102
+ if self.origin_file_pattern in [None, "", "./"]:
103
+ self.path = os.path.join(self.local_model_path, self.model_id)
104
+ else:
105
+ self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
106
+ if isinstance(self.path, list) and len(self.path) == 1:
107
+ self.path = self.path[0]
108
+
109
+ def vram_config(self):
110
+ return {
111
+ "offload_device": self.offload_device,
112
+ "offload_dtype": self.offload_dtype,
113
+ "onload_device": self.onload_device,
114
+ "onload_dtype": self.onload_dtype,
115
+ "preparing_device": self.preparing_device,
116
+ "preparing_dtype": self.preparing_dtype,
117
+ "computation_device": self.computation_device,
118
+ "computation_dtype": self.computation_dtype,
119
+ }
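Usage sketch: download_if_necessary resolves config.path, downloading from ModelScope by default (or Hugging Face via download_source / DIFFSYNTH_DOWNLOAD_SOURCE) under ./models unless local_model_path or DIFFSYNTH_MODEL_BASE_PATH overrides it. The model id and file pattern are taken from the example comments in model_configs.py above.

config = ModelConfig(
    model_id="Lightricks/LTX-2.3",
    origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors",
)
config.download_if_necessary()
print(config.path)  # local path(s) matching the pattern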
diffsynth/core/loader/file.py ADDED
@@ -0,0 +1,130 @@
+ from safetensors import safe_open
+ import torch, hashlib
+
+
+ def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
+     if isinstance(file_path, list):
+         state_dict = {}
+         for file_path_ in file_path:
+             state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
+     else:
+         if verbose >= 1:
+             print(f"Loading file [started]: {file_path}")
+         if file_path.endswith(".safetensors"):
+             state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
+         else:
+             state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
+         # If the state dict is loaded into CPU memory, `pin_memory=True` makes a subsequent `model.to("cuda")` faster.
+         if pin_memory:
+             for i in state_dict:
+                 state_dict[i] = state_dict[i].pin_memory()
+         if verbose >= 1:
+             print(f"Loading file [done]: {file_path}")
+     return state_dict
+
+
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
+     state_dict = {}
+     with safe_open(file_path, framework="pt", device=str(device)) as f:
+         for k in f.keys():
+             state_dict[k] = f.get_tensor(k)
+             if torch_dtype is not None:
+                 state_dict[k] = state_dict[k].to(torch_dtype)
+     return state_dict
+
+
+ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
+     state_dict = torch.load(file_path, map_location=device, weights_only=True)
+     if len(state_dict) == 1:
+         if "state_dict" in state_dict:
+             state_dict = state_dict["state_dict"]
+         elif "module" in state_dict:
+             state_dict = state_dict["module"]
+         elif "model_state" in state_dict:
+             state_dict = state_dict["model_state"]
+     if torch_dtype is not None:
+         for i in state_dict:
+             if isinstance(state_dict[i], torch.Tensor):
+                 state_dict[i] = state_dict[i].to(torch_dtype)
+     return state_dict
+
+
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
+     keys = []
+     for key, value in state_dict.items():
+         if isinstance(key, str):
+             if isinstance(value, torch.Tensor):
+                 if with_shape:
+                     shape = "_".join(map(str, list(value.shape)))
+                     keys.append(key + ":" + shape)
+                 keys.append(key)
+             elif isinstance(value, dict):
+                 keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
+     keys.sort()
+     keys_str = ",".join(keys)
+     return keys_str
+
+
+ def hash_state_dict_keys(state_dict, with_shape=True):
+     keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+     keys_str = keys_str.encode(encoding="UTF-8")
+     return hashlib.md5(keys_str).hexdigest()
+
+
+ def load_keys_dict(file_path):
+     if isinstance(file_path, list):
+         state_dict = {}
+         for file_path_ in file_path:
+             state_dict.update(load_keys_dict(file_path_))
+         return state_dict
+     if file_path.endswith(".safetensors"):
+         return load_keys_dict_from_safetensors(file_path)
+     else:
+         return load_keys_dict_from_bin(file_path)
+
+
+ def load_keys_dict_from_safetensors(file_path):
+     keys_dict = {}
+     with safe_open(file_path, framework="pt", device="cpu") as f:
+         for k in f.keys():
+             keys_dict[k] = f.get_slice(k).get_shape()
+     return keys_dict
+
+
+ def convert_state_dict_to_keys_dict(state_dict):
+     keys_dict = {}
+     for k, v in state_dict.items():
+         if isinstance(v, torch.Tensor):
+             keys_dict[k] = list(v.shape)
+         else:
+             keys_dict[k] = convert_state_dict_to_keys_dict(v)
+     return keys_dict
+
+
+ def load_keys_dict_from_bin(file_path):
+     state_dict = load_state_dict_from_bin(file_path)
+     keys_dict = convert_state_dict_to_keys_dict(state_dict)
+     return keys_dict
+
+
+ def convert_keys_dict_to_single_str(state_dict, with_shape=True):
+     keys = []
+     for key, value in state_dict.items():
+         if isinstance(key, str):
+             if isinstance(value, dict):
+                 keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
+             else:
+                 if with_shape:
+                     shape = "_".join(map(str, list(value)))
+                     keys.append(key + ":" + shape)
+                 keys.append(key)
+     keys.sort()
+     keys_str = ",".join(keys)
+     return keys_str
+
+
+ def hash_model_file(path, with_shape=True):
+     keys_dict = load_keys_dict(path)
+     keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
+     keys_str = keys_str.encode(encoding="UTF-8")
+     return hashlib.md5(keys_str).hexdigest()
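`hash_model_file` fingerprints a checkpoint by its tensor names and shapes without reading tensor data (safetensors slices expose shapes only), so two checkpoints with the same architecture hash identically even if their weights differ. A minimal sketch, assuming a local safetensors file:

digest = hash_model_file("model.safetensors")                             # hypothetical path
digest_no_shape = hash_model_file("model.safetensors", with_shape=False)
# The digest can serve as a key that maps a checkpoint to a model architecture.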
diffsynth/core/loader/model.py ADDED
@@ -0,0 +1,115 @@
+ from ..vram.initialization import skip_model_initialization
+ from ..vram.disk_map import DiskMap
+ from ..vram.layers import enable_vram_management
+ from .file import load_state_dict
+ import torch
+ from transformers.integrations import is_deepspeed_zero3_enabled
+ from transformers.utils import ContextManagers
+
+
+ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
+     config = {} if config is None else config
+     # Skip ZeRO-3 initialization for VAE to avoid compatibility issues.
+     skip_zero3 = 'vae' in model_class.__name__.lower() if hasattr(model_class, '__name__') else False
+     with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device, skip_zero3=skip_zero3)):
+         model = model_class(**config)
+     # What is `module_map`?
+     # It is a module mapping table for VRAM management.
+     if module_map is not None:
+         devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
+         device = [d for d in devices if d != "disk"][0]
+         dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
+         dtype = [d for d in dtypes if d != "disk"][0]
+         if vram_config["offload_device"] != "disk":
+             if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
+             if state_dict_converter is not None:
+                 state_dict = state_dict_converter(state_dict)
+             else:
+                 state_dict = {i: state_dict[i] for i in state_dict}
+             if is_deepspeed_zero3_enabled():
+                 from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
+                 _load_state_dict_into_zero3_model(model, state_dict)
+             else:
+                 model.load_state_dict(state_dict, assign=True)
+             model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
+         else:
+             disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
+             model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
+     else:
+         # Why do we use `DiskMap`?
+         # Sometimes a model file contains multiple models,
+         # and DiskMap can load only the parameters of a single model,
+         # avoiding the need to load all parameters in the file.
+         if state_dict is not None:
+             pass
+         elif use_disk_map:
+             state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
+         else:
+             state_dict = load_state_dict(path, torch_dtype, device)
+         # Why do we use `state_dict_converter`?
+         # Some models are saved in complex formats,
+         # and we need to convert the state dict into the appropriate format.
+         if state_dict_converter is not None:
+             state_dict = state_dict_converter(state_dict)
+         else:
+             state_dict = {i: state_dict[i] for i in state_dict}
+         # Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
+         # Because at this stage, model parameters are partitioned across multiple GPUs.
+         # Loading them directly could lead to excessive GPU memory consumption.
+         if is_deepspeed_zero3_enabled():
+             from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
+             _load_state_dict_into_zero3_model(model, state_dict)
+         else:
+             model.load_state_dict(state_dict, assign=True)
+         # Why do we call `to()`?
+         # Because some models override the behavior of `to()`,
+         # especially those from libraries like Transformers.
+         model = model.to(dtype=torch_dtype, device=device)
+     if hasattr(model, "eval"):
+         model = model.eval()
+     return model
+
+
+ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
+     if isinstance(path, str):
+         path = [path]
+     config = {} if config is None else config
+     with skip_model_initialization():
+         model = model_class(**config)
+     if hasattr(model, "eval"):
+         model = model.eval()
+     disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
+     vram_config = {
+         "offload_dtype": "disk",
+         "offload_device": "disk",
+         "onload_dtype": "disk",
+         "onload_device": "disk",
+         "preparing_dtype": torch.float8_e4m3fn,
+         "preparing_device": device,
+         "computation_dtype": torch_dtype,
+         "computation_device": device,
+     }
+     enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
+     return model
+
+
+ def get_init_context(torch_dtype, device, skip_zero3=False):
+     if is_deepspeed_zero3_enabled() and not skip_zero3:
+         from transformers.modeling_utils import set_zero3_state
+         import deepspeed
+         # Why do we use `deepspeed.zero.Init`?
+         # It partitions the model weights on the CPU side and then loads
+         # the partitioned weights onto the accelerator.
+         init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
+     elif skip_zero3:
+         # For models excluded from ZeRO-3 (e.g. VAE), use normal initialization
+         # instead of `skip_model_initialization` to avoid meta tensor issues.
+         init_contexts = []
+     else:
+         # Why do we use `skip_model_initialization`?
+         # It skips the random initialization of model parameters,
+         # thereby speeding up model loading and avoiding excessive memory usage.
+         init_contexts = [skip_model_initialization()]
+
+     return init_contexts
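A hedged usage sketch for `load_model_with_disk_offload`; `MyDiT` is a placeholder model class, and the module map wraps linear layers with `AutoWrappedLinear` from `diffsynth/core/vram/layers.py` (added later in this commit):

from diffsynth.core.vram.layers import AutoWrappedLinear

model = load_model_with_disk_offload(
    MyDiT,                                # hypothetical torch.nn.Module subclass
    path="models/dit.safetensors",        # hypothetical checkpoint path
    torch_dtype=torch.bfloat16,
    device="cuda",
    module_map={torch.nn.Linear: AutoWrappedLinear},
)
# Weights stay on disk; each wrapped layer streams its parameters in at forward time.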
diffsynth/core/npu_patch/npu_fused_operator.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from ..device.npu_compatible_device import get_device_type
+ try:
+     import torch_npu
+ except ImportError:
+     pass
+
+
+ def rms_norm_forward_npu(self, hidden_states):
+     "NPU fused RMSNorm operator, replacing RMSNorm.forward from diffsynth/models/general_modules.py"
+     if hidden_states.dtype != self.weight.dtype:
+         hidden_states = hidden_states.to(self.weight.dtype)
+     return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
+
+
+ def rms_norm_forward_transformers_npu(self, hidden_states):
+     "NPU fused RMSNorm operator for transformers"
+     if hidden_states.dtype != self.weight.dtype:
+         hidden_states = hidden_states.to(self.weight.dtype)
+     return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+
+
+ def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
+     "NPU fused RoPE operator for Z-Image"
+     with torch.amp.autocast(get_device_type(), enabled=False):
+         freqs_cis = freqs_cis.unsqueeze(2)
+         cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
+         cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
+         sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
+         return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
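These functions are drop-in replacements for the corresponding `forward` methods. A sketch of how such a patch might be applied; the attachment point is an assumption, since the actual dispatch lives elsewhere in the library:

from diffsynth.models.general_modules import RMSNorm  # assumed import path

RMSNorm.forward = rms_norm_forward_npu  # hypothetical monkey-patch, applied on NPU devices only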
diffsynth/core/vram/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .initialization import skip_model_initialization
+ from .layers import *
diffsynth/core/vram/disk_map.py ADDED
@@ -0,0 +1,93 @@
+ from safetensors import safe_open
+ import torch, os
+
+
+ class SafetensorsCompatibleTensor:
+     def __init__(self, tensor):
+         self.tensor = tensor
+
+     def get_shape(self):
+         return list(self.tensor.shape)
+
+
+ class SafetensorsCompatibleBinaryLoader:
+     def __init__(self, path, device):
+         print("Detected a non-safetensors file, which may slow down loading. It is recommended to convert it to a safetensors file.")
+         self.state_dict = torch.load(path, weights_only=True, map_location=device)
+
+     def keys(self):
+         return self.state_dict.keys()
+
+     def get_tensor(self, name):
+         return self.state_dict[name]
+
+     def get_slice(self, name):
+         return SafetensorsCompatibleTensor(self.state_dict[name])
+
+
+ class DiskMap:
+
+     def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
+         self.path = path if isinstance(path, list) else [path]
+         self.device = device
+         self.torch_dtype = torch_dtype
+         if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
+             self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
+         else:
+             self.buffer_size = buffer_size
+         self.files = []
+         self.flush_files()
+         self.name_map = {}
+         for file_id, file in enumerate(self.files):
+             for name in file.keys():
+                 self.name_map[name] = file_id
+         self.rename_dict = self.fetch_rename_dict(state_dict_converter)
+
+     def flush_files(self):
+         if len(self.files) == 0:
+             for path in self.path:
+                 if path.endswith(".safetensors"):
+                     self.files.append(safe_open(path, framework="pt", device=str(self.device)))
+                 else:
+                     self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
+         else:
+             for i, path in enumerate(self.path):
+                 if path.endswith(".safetensors"):
+                     self.files[i] = safe_open(path, framework="pt", device=str(self.device))
+         self.num_params = 0
+
+     def __getitem__(self, name):
+         if self.rename_dict is not None: name = self.rename_dict[name]
+         file_id = self.name_map[name]
+         param = self.files[file_id].get_tensor(name)
+         if self.torch_dtype is not None and isinstance(param, torch.Tensor):
+             param = param.to(self.torch_dtype)
+         # `.clone()` detaches the tensor from the underlying mmap so the file handle can be released.
+         if isinstance(param, torch.Tensor) and param.device.type == "cpu":
+             param = param.clone()
+         if isinstance(param, torch.Tensor):
+             self.num_params += param.numel()
+             if self.num_params > self.buffer_size:
+                 self.flush_files()
+         return param
+
+     def fetch_rename_dict(self, state_dict_converter):
+         if state_dict_converter is None:
+             return None
+         state_dict = {}
+         for file in self.files:
+             for name in file.keys():
+                 state_dict[name] = name
+         state_dict = state_dict_converter(state_dict)
+         return state_dict
+
+     def __iter__(self):
+         if self.rename_dict is not None:
+             return self.rename_dict.__iter__()
+         else:
+             return self.name_map.__iter__()
+
+     def __contains__(self, x):
+         if self.rename_dict is not None:
+             return x in self.rename_dict
+         else:
+             return x in self.name_map
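A sketch of `DiskMap` as a lazy, dict-like state dict; the file path and tensor key are hypothetical:

disk_map = DiskMap("model.safetensors", device="cpu", torch_dtype=torch.bfloat16)
names = list(disk_map)                        # tensor names, no data loaded yet
weight = disk_map["blocks.0.attn.q.weight"]   # loads a single tensor from disk
# Once roughly `buffer_size` elements have been read, `flush_files` reopens the
# safetensors handles to release the memory held by the underlying mmap.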
diffsynth/core/vram/initialization.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ from contextlib import contextmanager
+
+
+ @contextmanager
+ def skip_model_initialization(device=torch.device("meta")):
+
+     def register_empty_parameter(module, name, param):
+         old_register_parameter(module, name, param)
+         if param is not None:
+             param_cls = type(module._parameters[name])
+             kwargs = module._parameters[name].__dict__
+             kwargs["requires_grad"] = param.requires_grad
+             module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+     old_register_parameter = torch.nn.Module.register_parameter
+     torch.nn.Module.register_parameter = register_empty_parameter
+     try:
+         yield
+     finally:
+         torch.nn.Module.register_parameter = old_register_parameter
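A quick illustration of the context manager: inside it, every newly registered parameter is moved to the `meta` device, so no storage is allocated and no random-initialization cost is paid:

with skip_model_initialization():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta
# Real weights are attached later, e.g. via load_state_dict(..., assign=True).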
diffsynth/core/vram/layers.py ADDED
@@ -0,0 +1,479 @@
+ import torch, copy
+ from typing import Union
+ from .initialization import skip_model_initialization
+ from .disk_map import DiskMap
+ from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
+
+
+ class AutoTorchModule(torch.nn.Module):
+
+     def __init__(
+         self,
+         offload_dtype: torch.dtype = None,
+         offload_device: Union[str, torch.device] = None,
+         onload_dtype: torch.dtype = None,
+         onload_device: Union[str, torch.device] = None,
+         preparing_dtype: torch.dtype = None,
+         preparing_device: Union[str, torch.device] = None,
+         computation_dtype: torch.dtype = None,
+         computation_device: Union[str, torch.device] = None,
+         vram_limit: float = None,
+     ):
+         super().__init__()
+         self.set_dtype_and_device(
+             offload_dtype,
+             offload_device,
+             onload_dtype,
+             onload_device,
+             preparing_dtype,
+             preparing_device,
+             computation_dtype,
+             computation_device,
+             vram_limit,
+         )
+         self.state = 0
+         self.name = ""
+         self.computation_device_type = parse_device_type(self.computation_device)
+
+     def set_dtype_and_device(
+         self,
+         offload_dtype: torch.dtype = None,
+         offload_device: Union[str, torch.device] = None,
+         onload_dtype: torch.dtype = None,
+         onload_device: Union[str, torch.device] = None,
+         preparing_dtype: torch.dtype = None,
+         preparing_device: Union[str, torch.device] = None,
+         computation_dtype: torch.dtype = None,
+         computation_device: Union[str, torch.device] = None,
+         vram_limit: float = None,
+     ):
+         self.offload_dtype = offload_dtype or computation_dtype
+         self.offload_device = offload_device or computation_device
+         self.onload_dtype = onload_dtype or computation_dtype
+         self.onload_device = onload_device or computation_device
+         self.preparing_dtype = preparing_dtype or computation_dtype
+         self.preparing_device = preparing_device or computation_device
+         self.computation_dtype = computation_dtype
+         self.computation_device = computation_device
+         self.vram_limit = vram_limit
+
+     def cast_to(self, weight, dtype, device):
+         r = torch.empty_like(weight, dtype=dtype, device=device)
+         r.copy_(weight)
+         return r
+
+     def check_free_vram(self):
+         device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
+         gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
+         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
+         return used_memory < self.vram_limit
+
+     def offload(self):
+         if self.state != 0:
+             self.to(dtype=self.offload_dtype, device=self.offload_device)
+             self.state = 0
+
+     def onload(self):
+         if self.state != 1:
+             self.to(dtype=self.onload_dtype, device=self.onload_device)
+             self.state = 1
+
+     def param_name(self, name):
+         if self.name == "":
+             return name
+         else:
+             return self.name + "." + name
+
+
+ class AutoWrappedModule(AutoTorchModule):
+
+     def __init__(
+         self,
+         module: torch.nn.Module,
+         offload_dtype: torch.dtype = None,
+         offload_device: Union[str, torch.device] = None,
+         onload_dtype: torch.dtype = None,
+         onload_device: Union[str, torch.device] = None,
+         preparing_dtype: torch.dtype = None,
+         preparing_device: Union[str, torch.device] = None,
+         computation_dtype: torch.dtype = None,
+         computation_device: Union[str, torch.device] = None,
+         vram_limit: float = None,
+         name: str = "",
+         disk_map: DiskMap = None,
+         **kwargs
+     ):
+         super().__init__(
+             offload_dtype,
+             offload_device,
+             onload_dtype,
+             onload_device,
+             preparing_dtype,
+             preparing_device,
+             computation_dtype,
+             computation_device,
+             vram_limit,
+         )
+         self.module = module
+         if offload_dtype == "disk":
+             self.name = name
+             self.disk_map = disk_map
+             self.required_params = [name for name, _ in self.module.named_parameters()]
+             self.disk_offload = True
+         else:
+             self.disk_offload = False
+
+     def load_from_disk(self, torch_dtype, device, copy_module=False):
+         if copy_module:
+             module = copy.deepcopy(self.module)
+         else:
+             module = self.module
+         state_dict = {}
+         for name in self.required_params:
+             param = self.disk_map[self.param_name(name)]
+             param = param.to(dtype=torch_dtype, device=device)
+             state_dict[name] = param
+         module.load_state_dict(state_dict, assign=True)
+         module.to(dtype=torch_dtype, device=device)
+         return module
+
+     def offload_to_disk(self, model: torch.nn.Module):
+         for buf in model.buffers():
+             # If some parameters are registered as buffers (they are not in the state dict),
+             # we cannot offload this module as a whole; recurse into its children instead.
+             for child in model.children():
+                 self.offload_to_disk(child)
+             break
+         else:
+             model.to("meta")
+
+     def offload(self):
+         # offload / onload / preparing -> offload
+         if self.state != 0:
+             if self.disk_offload:
+                 self.offload_to_disk(self.module)
+             else:
+                 self.to(dtype=self.offload_dtype, device=self.offload_device)
+             self.state = 0
+
+     def onload(self):
+         # offload / onload / preparing -> onload
+         if self.state < 1:
+             if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
+                 self.load_from_disk(self.onload_dtype, self.onload_device)
+             elif self.onload_device != "disk":
+                 self.to(dtype=self.onload_dtype, device=self.onload_device)
+             self.state = 1
+
+     def preparing(self):
+         # onload / preparing -> preparing
+         if self.state != 2:
+             if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
+                 self.load_from_disk(self.preparing_dtype, self.preparing_device)
+             elif self.preparing_device != "disk":
+                 self.to(dtype=self.preparing_dtype, device=self.preparing_device)
+             self.state = 2
+
+     def cast_to(self, module, dtype, device):
+         return copy.deepcopy(module).to(dtype=dtype, device=device)
+
+     def computation(self):
+         # onload / preparing -> computation (temporary)
+         if self.state == 2:
+             torch_dtype, device = self.preparing_dtype, self.preparing_device
+         else:
+             torch_dtype, device = self.onload_dtype, self.onload_device
+         if torch_dtype == self.computation_dtype and device == self.computation_device:
+             module = self.module
+         elif self.disk_offload and device == "disk":
+             module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
+         else:
+             module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
+         return module
+
+     def forward(self, *args, **kwargs):
+         if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
+             self.preparing()
+         module = self.computation()
+         return module(*args, **kwargs)
+
+     def __getattr__(self, name):
+         if name in self.__dict__ or name == "module":
+             return super().__getattr__(name)
+         else:
+             return getattr(self.module, name)
+
+
+ class AutoWrappedNonRecurseModule(AutoWrappedModule):
+
+     def __init__(
+         self,
+         module: torch.nn.Module,
+         offload_dtype: torch.dtype = None,
+         offload_device: Union[str, torch.device] = None,
+         onload_dtype: torch.dtype = None,
+         onload_device: Union[str, torch.device] = None,
+         preparing_dtype: torch.dtype = None,
+         preparing_device: Union[str, torch.device] = None,
+         computation_dtype: torch.dtype = None,
+         computation_device: Union[str, torch.device] = None,
+         vram_limit: float = None,
+         name: str = "",
+         disk_map: DiskMap = None,
+         **kwargs
+     ):
+         super().__init__(
+             module,
+             offload_dtype,
+             offload_device,
+             onload_dtype,
+             onload_device,
+             preparing_dtype,
+             preparing_device,
+             computation_dtype,
+             computation_device,
+             vram_limit,
+             name,
+             disk_map,
+             **kwargs
+         )
+         if self.disk_offload:
+             self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
+
+     def load_from_disk(self, torch_dtype, device, copy_module=False):
+         if copy_module:
+             module = copy.deepcopy(self.module)
+         else:
+             module = self.module
+         state_dict = {}
+         for name in self.required_params:
+             param = self.disk_map[self.param_name(name)]
+             param = param.to(dtype=torch_dtype, device=device)
+             state_dict[name] = param
+         module.load_state_dict(state_dict, assign=True, strict=False)
+         return module
+
+     def offload_to_disk(self, model: torch.nn.Module):
+         # `Tensor.to` is not in-place, so the moved parameters must be reassigned on the wrapped module.
+         for name in self.required_params:
+             param = self.module._parameters[name]
+             self.module._parameters[name] = torch.nn.Parameter(param.to("meta"), requires_grad=param.requires_grad)
+
+     def cast_to(self, module, dtype, device):
+         # Parameter casting is implemented in the model architecture.
+         return module
+
+     def __getattr__(self, name):
+         if name in self.__dict__ or name == "module":
+             return super().__getattr__(name)
+         else:
+             return getattr(self.module, name)
+
+
+ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
+     def __init__(
+         self,
+         module: torch.nn.Linear,
+         offload_dtype: torch.dtype = None,
+         offload_device: Union[str, torch.device] = None,
+         onload_dtype: torch.dtype = None,
+         onload_device: Union[str, torch.device] = None,
+         preparing_dtype: torch.dtype = None,
+         preparing_device: Union[str, torch.device] = None,
+         computation_dtype: torch.dtype = None,
+         computation_device: Union[str, torch.device] = None,
+         vram_limit: float = None,
+         name: str = "",
+         disk_map: DiskMap = None,
+         **kwargs
+     ):
+         with skip_model_initialization():
+             super().__init__(
+                 in_features=module.in_features,
+                 out_features=module.out_features,
+                 bias=module.bias is not None,
+             )
+         self.set_dtype_and_device(
+             offload_dtype,
+             offload_device,
+             onload_dtype,
+             onload_device,
+             preparing_dtype,
+             preparing_device,
+             computation_dtype,
+             computation_device,
+             vram_limit,
+         )
+         self.weight = module.weight
+         self.bias = module.bias
+         self.state = 0
+         self.name = name
+         self.lora_A_weights = []
+         self.lora_B_weights = []
+         self.lora_merger = None
+         self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+         self.computation_device_type = parse_device_type(self.computation_device)
+
+         if offload_dtype == "disk":
+             self.disk_map = disk_map
+             self.disk_offload = True
+         else:
+             self.disk_offload = False
+
+     def fp8_linear(
+         self,
+         input: torch.Tensor,
+         weight: torch.Tensor,
+         bias: torch.Tensor = None,
+     ) -> torch.Tensor:
+         device = input.device
+         origin_dtype = input.dtype
+         origin_shape = input.shape
+         input = input.reshape(-1, origin_shape[-1])
+
+         x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+         fp8_max = 448.0
+         # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
+         # To avoid overflow and ensure numerical compatibility during FP8 computation,
+         # we scale down the input by 2.0 in advance.
+         # This scaling will be compensated later during the final result scaling.
+         if self.computation_dtype == torch.float8_e4m3fnuz:
+             fp8_max = fp8_max / 2.0
+         scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+         scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+         input = input / (scale_a + 1e-8)
+         input = input.to(self.computation_dtype)
+         weight = weight.to(self.computation_dtype)
+         if bias is not None:
+             bias = bias.to(torch.bfloat16)
+
+         result = torch._scaled_mm(
+             input,
+             weight.T,
+             scale_a=scale_a,
+             scale_b=scale_b.T,
+             bias=bias,
+             out_dtype=origin_dtype,
+         )
+         new_shape = origin_shape[:-1] + result.shape[-1:]
+         result = result.reshape(new_shape)
+         return result
+
+     def load_from_disk(self, torch_dtype, device, assign=True):
+         weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
+         bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
+         if assign:
+             state_dict = {"weight": weight}
+             if bias is not None: state_dict["bias"] = bias
+             self.load_state_dict(state_dict, assign=True)
+         return weight, bias
+
+     def offload(self):
+         # offload / onload / preparing -> offload
+         if self.state != 0:
+             if self.disk_offload:
+                 self.to("meta")
+             else:
+                 self.to(dtype=self.offload_dtype, device=self.offload_device)
+             self.state = 0
+
+     def onload(self):
+         # offload / onload / preparing -> onload
+         if self.state < 1:
+             if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
+                 self.load_from_disk(self.onload_dtype, self.onload_device)
+             elif self.onload_device != "disk":
+                 self.to(dtype=self.onload_dtype, device=self.onload_device)
+             self.state = 1
+
+     def preparing(self):
+         # onload / preparing -> preparing
+         if self.state != 2:
+             if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
+                 self.load_from_disk(self.preparing_dtype, self.preparing_device)
+             elif self.preparing_device != "disk":
+                 self.to(dtype=self.preparing_dtype, device=self.preparing_device)
+             self.state = 2
+
+     def computation(self):
+         # onload / preparing -> computation (temporary)
+         if self.state == 2:
+             torch_dtype, device = self.preparing_dtype, self.preparing_device
+         else:
+             torch_dtype, device = self.onload_dtype, self.onload_device
+         if torch_dtype == self.computation_dtype and device == self.computation_device:
+             weight, bias = self.weight, self.bias
+         elif self.disk_offload and device == "disk":
+             weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
+         else:
+             weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
+             bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
+         return weight, bias
+
+     def linear_forward(self, x, weight, bias):
+         if self.enable_fp8:
+             out = self.fp8_linear(x, weight, bias)
+         else:
+             out = torch.nn.functional.linear(x, weight, bias)
+         return out
+
+     def lora_forward(self, x, out):
+         if self.lora_merger is None:
+             for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                 out = out + x @ lora_A.T.to(device=x.device, dtype=x.dtype) @ lora_B.T.to(device=x.device, dtype=x.dtype)
+         else:
+             lora_output = []
+             for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
+                 lora_output.append(x @ lora_A.T @ lora_B.T)
+             lora_output = torch.stack(lora_output)
+             out = self.lora_merger(out, lora_output)
+         return out
+
+     def forward(self, x, *args, **kwargs):
+         if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
+             self.preparing()
+         weight, bias = self.computation()
+         out = self.linear_forward(x, weight, bias)
+         if len(self.lora_A_weights) > 0:
+             out = self.lora_forward(x, out)
+         return out
+
+
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
+     if isinstance(model, AutoWrappedNonRecurseModule):
+         model = model.module
+     for name, module in model.named_children():
+         layer_name = name if name_prefix == "" else name_prefix + "." + name
+         for source_module, target_module in module_map.items():
+             if isinstance(module, source_module):
+                 module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
+                 if isinstance(module_, AutoWrappedNonRecurseModule):
+                     enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
+                 setattr(model, name, module_)
+                 break
+         else:
+             enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
+
+
+ def fill_vram_config(model, vram_config):
+     vram_config_ = vram_config.copy()
+     vram_config_["onload_dtype"] = vram_config["computation_dtype"]
+     vram_config_["onload_device"] = vram_config["computation_device"]
+     vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
+     vram_config_["preparing_device"] = vram_config["computation_device"]
+     for k in vram_config:
+         if vram_config[k] != vram_config_[k]:
+             print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
+             break
+     return vram_config_
+
+
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
+     for source_module, target_module in module_map.items():
+         # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
+         if isinstance(model, source_module):
+             vram_config = fill_vram_config(model, vram_config)
+             model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
+             break
+     else:
+         enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
+     # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
+     model.vram_management_enabled = True
+     return model
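A hedged end-to-end sketch of the offload/onload/preparing/computation state machine; the model, devices, and VRAM limit are placeholders:

# Assumed usage sketch; `model` is any torch.nn.Module.
vram_config = {
    "offload_dtype": torch.float8_e4m3fn, "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn, "onload_device": "cuda",
    "preparing_dtype": torch.float8_e4m3fn, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}
model = enable_vram_management(model, {torch.nn.Linear: AutoWrappedLinear}, vram_config, vram_limit=20)
for m in model.modules():
    if hasattr(m, "onload"):
        m.onload()  # fp8 weights move to the GPU; bf16 casting happens per layer at forward time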
diffsynth/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .flow_match import FlowMatchScheduler
+ from .training_module import DiffusionTrainingModule
+ from .logger import ModelLogger
+ from .runner import launch_training_task, launch_data_process_task
+ from .parsers import *
+ from .loss import *
diffsynth/diffusion/base_pipeline.py ADDED
@@ -0,0 +1,500 @@
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ from einops import repeat, reduce
5
+ from typing import Union
6
+ from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
7
+ from ..core.device.npu_compatible_device import get_device_type
8
+ from ..utils.lora import GeneralLoRALoader
9
+ from ..models.model_loader import ModelPool
10
+ from ..utils.controlnet import ControlNetInput
11
+ from ..core.device import get_device_name, IS_NPU_AVAILABLE
12
+
13
+
14
+ class PipelineUnit:
15
+ def __init__(
16
+ self,
17
+ seperate_cfg: bool = False,
18
+ take_over: bool = False,
19
+ input_params: tuple[str] = None,
20
+ output_params: tuple[str] = None,
21
+ input_params_posi: dict[str, str] = None,
22
+ input_params_nega: dict[str, str] = None,
23
+ onload_model_names: tuple[str] = None
24
+ ):
25
+ self.seperate_cfg = seperate_cfg
26
+ self.take_over = take_over
27
+ self.input_params = input_params
28
+ self.output_params = output_params
29
+ self.input_params_posi = input_params_posi
30
+ self.input_params_nega = input_params_nega
31
+ self.onload_model_names = onload_model_names
32
+
33
+ def fetch_input_params(self):
34
+ params = []
35
+ if self.input_params is not None:
36
+ for param in self.input_params:
37
+ params.append(param)
38
+ if self.input_params_posi is not None:
39
+ for _, param in self.input_params_posi.items():
40
+ params.append(param)
41
+ if self.input_params_nega is not None:
42
+ for _, param in self.input_params_nega.items():
43
+ params.append(param)
44
+ params = sorted(list(set(params)))
45
+ return params
46
+
47
+ def fetch_output_params(self):
48
+ params = []
49
+ if self.output_params is not None:
50
+ for param in self.output_params:
51
+ params.append(param)
52
+ return params
53
+
54
+ def process(self, pipe, **kwargs) -> dict:
55
+ return {}
56
+
57
+ def post_process(self, pipe, **kwargs) -> dict:
58
+ return {}
59
+
60
+
61
+ class BasePipeline(torch.nn.Module):
62
+
63
+ def __init__(
64
+ self,
65
+ device=get_device_type(), torch_dtype=torch.float16,
66
+ height_division_factor=64, width_division_factor=64,
67
+ time_division_factor=None, time_division_remainder=None,
68
+ ):
69
+ super().__init__()
70
+ # The device and torch_dtype is used for the storage of intermediate variables, not models.
71
+ self.device = device
72
+ self.torch_dtype = torch_dtype
73
+ self.device_type = parse_device_type(device)
74
+ # The following parameters are used for shape check.
75
+ self.height_division_factor = height_division_factor
76
+ self.width_division_factor = width_division_factor
77
+ self.time_division_factor = time_division_factor
78
+ self.time_division_remainder = time_division_remainder
79
+ # VRAM management
80
+ self.vram_management_enabled = False
81
+ # Pipeline Unit Runner
82
+ self.unit_runner = PipelineUnitRunner()
83
+ # LoRA Loader
84
+ self.lora_loader = GeneralLoRALoader
85
+
86
+
87
+ def to(self, *args, **kwargs):
88
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
89
+ if device is not None:
90
+ self.device = device
91
+ if dtype is not None:
92
+ self.torch_dtype = dtype
93
+ super().to(*args, **kwargs)
94
+ return self
95
+
96
+
97
+ def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
98
+ # Shape check
99
+ if height % self.height_division_factor != 0:
100
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
101
+ if verbose > 0:
102
+ print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
103
+ if width % self.width_division_factor != 0:
104
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
105
+ if verbose > 0:
106
+ print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
107
+ if num_frames is None:
108
+ return height, width
109
+ else:
110
+ if num_frames % self.time_division_factor != self.time_division_remainder:
111
+ num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
112
+ if verbose > 0:
113
+ print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
114
+ return height, width, num_frames
115
+
116
+
117
+ def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
118
+ # Transform a PIL.Image to torch.Tensor
119
+ image = torch.Tensor(np.array(image, dtype=np.float32))
120
+ image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
121
+ image = image * ((max_value - min_value) / 255) + min_value
122
+ image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
123
+ return image
124
+
125
+
126
+ def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
127
+ # Transform a list of PIL.Image to torch.Tensor
128
+ video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
129
+ video = torch.stack(video, dim=pattern.index("T") // 2)
130
+ return video
131
+
132
+
133
+ def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
134
+ # Transform a torch.Tensor to PIL.Image
135
+ if pattern != "H W C":
136
+ vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
137
+ image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
138
+ image = image.to(device="cpu", dtype=torch.uint8)
139
+ image = Image.fromarray(image.numpy())
140
+ return image
141
+
142
+
143
+ def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
144
+ # Transform a torch.Tensor to list of PIL.Image
145
+ if pattern != "T H W C":
146
+ vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
147
+ video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
148
+ return video
149
+
150
+ def output_audio_format_check(self, audio_output):
151
+ # output standard foramt: [C, T], output dtype: float()
152
+ # remove batch dim
153
+ if audio_output.ndim == 3:
154
+ audio_output = audio_output.squeeze(0)
155
+ return audio_output.float()
156
+
157
+ def load_models_to_device(self, model_names):
158
+ if self.vram_management_enabled:
159
+ # offload models
160
+ for name, model in self.named_children():
161
+ if name not in model_names:
162
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
163
+ if hasattr(model, "offload"):
164
+ model.offload()
165
+ else:
166
+ for module in model.modules():
167
+ if hasattr(module, "offload"):
168
+ module.offload()
169
+ getattr(torch, self.device_type).empty_cache()
170
+ # onload models
171
+ for name, model in self.named_children():
172
+ if name in model_names:
173
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
174
+ if hasattr(model, "onload"):
175
+ model.onload()
176
+ else:
177
+ for module in model.modules():
178
+ if hasattr(module, "onload"):
179
+ module.onload()
180
+
181
+
182
+ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
183
+ # Initialize Gaussian noise
184
+ generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
185
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
186
+ noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
187
+ return noise
188
+
189
+
190
+ def get_vram(self):
191
+ device = self.device if not IS_NPU_AVAILABLE else get_device_name()
192
+ return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
193
+
194
+ def get_module(self, model, name):
195
+ if "." in name:
196
+ name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
197
+ if name.isdigit():
198
+ return self.get_module(model[int(name)], suffix)
199
+ else:
200
+ return self.get_module(getattr(model, name), suffix)
201
+ else:
202
+ return getattr(model, name)
203
+
204
+ def freeze_except(self, model_names):
205
+ self.eval()
206
+ self.requires_grad_(False)
207
+ for name in model_names:
208
+ module = self.get_module(self, name)
209
+ if module is None:
210
+ print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
211
+ continue
212
+ module.train()
213
+ module.requires_grad_(True)
214
+
215
+
216
+ def blend_with_mask(self, base, addition, mask):
217
+ return base * (1 - mask) + addition * mask
218
+
219
+
220
+ def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
221
+ timestep = scheduler.timesteps[progress_id]
222
+ if inpaint_mask is not None:
223
+ noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
224
+ noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
225
+ latents_next = scheduler.step(noise_pred, timestep, latents)
226
+ return latents_next
227
+
228
+
229
+ def split_pipeline_units(self, model_names: list[str]):
230
+ return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
231
+
232
+
233
+ def flush_vram_management_device(self, device):
234
+ for module in self.modules():
235
+ if isinstance(module, AutoTorchModule):
236
+ module.offload_device = device
237
+ module.onload_device = device
238
+ module.preparing_device = device
239
+ module.computation_device = device
240
+
241
+
242
+ def load_lora(
243
+ self,
244
+ module: torch.nn.Module,
245
+ lora_config: Union[ModelConfig, str] = None,
246
+ alpha=1,
247
+ hotload=None,
248
+ state_dict=None,
249
+ verbose=1,
250
+ ):
251
+ if state_dict is None:
252
+ if isinstance(lora_config, str):
253
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
254
+ else:
255
+ lora_config.download_if_necessary()
256
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
257
+ else:
258
+ lora = state_dict
259
+ lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
260
+ lora = lora_loader.convert_state_dict(lora)
261
+ if hotload is None:
262
+ hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
263
+ if hotload:
264
+ if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
265
+ raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
266
+ updated_num = 0
267
+ for _, module in module.named_modules():
268
+ if isinstance(module, AutoWrappedLinear):
269
+ name = module.name
270
+ lora_a_name = f'{name}.lora_A.weight'
271
+ lora_b_name = f'{name}.lora_B.weight'
272
+ if lora_a_name in lora and lora_b_name in lora:
273
+ updated_num += 1
274
+ module.lora_A_weights.append(lora[lora_a_name] * alpha)
275
+ module.lora_B_weights.append(lora[lora_b_name])
276
+ if verbose >= 1:
277
+ print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
278
+ else:
279
+ lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
280
+
281
+
282
+ def clear_lora(self, verbose=1):
283
+ cleared_num = 0
284
+ for name, module in self.named_modules():
285
+ if isinstance(module, AutoWrappedLinear):
286
+ if hasattr(module, "lora_A_weights"):
287
+ if len(module.lora_A_weights) > 0:
288
+ cleared_num += 1
289
+ module.lora_A_weights.clear()
290
+ if hasattr(module, "lora_B_weights"):
291
+ module.lora_B_weights.clear()
292
+ if verbose >= 1:
293
+ print(f"{cleared_num} LoRA layers are cleared.")
294
+
295
+
296
+ def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
297
+ model_pool = ModelPool()
298
+ for model_config in model_configs:
299
+ model_config.download_if_necessary()
300
+ vram_config = model_config.vram_config()
301
+ vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
302
+ vram_config["computation_device"] = vram_config["computation_device"] or self.device
303
+ model_pool.auto_load_model(
304
+ model_config.path,
305
+ vram_config=vram_config,
306
+ vram_limit=vram_limit,
307
+ clear_parameters=model_config.clear_parameters,
308
+ state_dict=model_config.state_dict,
309
+ )
310
+ return model_pool
311
+
312
+
313
+ def check_vram_management_state(self):
314
+ vram_management_enabled = False
315
+ for module in self.children():
316
+ if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
317
+ vram_management_enabled = True
318
+ return vram_management_enabled
319
+
320
+
321
+ def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
322
+ if inputs_shared.get("positive_only_lora", None) is not None:
323
+ self.clear_lora(verbose=0)
324
+ self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
325
+ noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
326
+ if cfg_scale != 1.0:
327
+ if inputs_shared.get("positive_only_lora", None) is not None:
328
+ self.clear_lora(verbose=0)
329
+ noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
330
+ if isinstance(noise_pred_posi, tuple):
331
+ # Separately handling different output types of latents, eg. video and audio latents.
332
+ noise_pred = tuple(
333
+ n_nega + cfg_scale * (n_posi - n_nega)
334
+ for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
335
+ )
336
+ else:
337
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
338
+ else:
339
+ noise_pred = noise_pred_posi
340
+ return noise_pred
341
+
342
+ def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
343
+ """
344
+ compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
345
+ If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
346
+ See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
347
+ Args:
348
+ mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
349
+ dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
350
+ fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
351
+ compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
352
+ **kwargs: Other arguments for `torch.compile`.
353
+ """
354
+ compile_models = compile_models or getattr(self, "compilable_models", [])
355
+ if len(compile_models) == 0:
356
+ print("No compilable models in the pipeline. Skip compilation.")
357
+ return
358
+ for name in compile_models:
359
+ model = getattr(self, name, None)
360
+ if model is None:
361
+ print(f"Model '{name}' not found in the pipeline.")
362
+ continue
363
+ repeated_blocks = getattr(model, "_repeated_blocks", None)
364
+ # regional compilation for repeated blocks.
365
+ if repeated_blocks is not None:
366
+ for submod in model.modules():
367
+ if submod.__class__.__name__ in repeated_blocks:
368
+ submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
369
+ # compile the whole model.
370
+ else:
371
+ model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
372
+ print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
373
+
374
+
375
+ class PipelineUnitGraph:
376
+ def __init__(self):
377
+ pass
378
+
379
+ def build_edges(self, units: list[PipelineUnit]):
380
+ # Establish dependencies between units
381
+        # to search for subsequent related computation units.
+        last_compute_unit_id = {}
+        edges = []
+        for unit_id, unit in enumerate(units):
+            for input_param in unit.fetch_input_params():
+                if input_param in last_compute_unit_id:
+                    edges.append((last_compute_unit_id[input_param], unit_id))
+            for output_param in unit.fetch_output_params():
+                last_compute_unit_id[output_param] = unit_id
+        return edges
+
+    def build_chains(self, units: list[PipelineUnit]):
+        # Establish updating chains for each variable
+        # to track their computation process.
+        params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
+        params = sorted(list(set(params)))
+        chains = {param: [] for param in params}
+        for unit_id, unit in enumerate(units):
+            for param in unit.fetch_output_params():
+                chains[param].append(unit_id)
+        return chains
+
+    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
+        # Search for units that directly participate in the model's computation.
+        related_unit_ids = []
+        for unit_id, unit in enumerate(units):
+            for model_name in model_names:
+                if unit.onload_model_names is not None and model_name in unit.onload_model_names:
+                    related_unit_ids.append(unit_id)
+                    break
+        return related_unit_ids
+
+    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
+        # Search for subsequent related computation units.
+        related_unit_ids = [unit_id for unit_id in start_unit_ids]
+        while True:
+            neighbors = []
+            for source, target in edges:
+                if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
+                    neighbors.append(target)
+                elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
+                    neighbors.append(source)
+            neighbors = sorted(list(set(neighbors)))
+            if len(neighbors) == 0:
+                break
+            else:
+                related_unit_ids.extend(neighbors)
+        related_unit_ids = sorted(list(set(related_unit_ids)))
+        return related_unit_ids
+
+    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
+        # If the input parameters of this subgraph are updated outside the subgraph,
+        # search for the units where these updates occur.
+        first_compute_unit_id = {}
+        for unit_id in related_unit_ids:
+            for param in units[unit_id].fetch_input_params():
+                if param not in first_compute_unit_id:
+                    first_compute_unit_id[param] = unit_id
+        updating_unit_ids = []
+        for param in first_compute_unit_id:
+            unit_id = first_compute_unit_id[param]
+            chain = chains[param]
+            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
+                for unit_id_ in chain[chain.index(unit_id) + 1:]:
+                    if unit_id_ not in related_unit_ids:
+                        updating_unit_ids.append(unit_id_)
+        related_unit_ids.extend(updating_unit_ids)
+        related_unit_ids = sorted(list(set(related_unit_ids)))
+        return related_unit_ids
+
+    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
+        # Split the computation graph,
+        # separating all model-related computations.
+        related_unit_ids = self.search_direct_unit_ids(units, model_names)
+        edges = self.build_edges(units)
+        chains = self.build_chains(units)
+        while True:
+            num_related_unit_ids = len(related_unit_ids)
+            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
+            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
+            if len(related_unit_ids) == num_related_unit_ids:
+                break
+            else:
+                num_related_unit_ids = len(related_unit_ids)
+        related_units = [units[i] for i in related_unit_ids]
+        unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
+        return related_units, unrelated_units
+
+
+class PipelineUnitRunner:
+    def __init__(self):
+        pass
+
+    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
+        if unit.take_over:
+            # Let the pipeline unit take over this function.
+            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
+        elif unit.seperate_cfg:
+            # Positive side
+            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
+            if unit.input_params is not None:
+                for name in unit.input_params:
+                    processor_inputs[name] = inputs_shared.get(name)
+            processor_outputs = unit.process(pipe, **processor_inputs)
+            inputs_posi.update(processor_outputs)
+            # Negative side
+            if inputs_shared["cfg_scale"] != 1:
+                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
+                if unit.input_params is not None:
+                    for name in unit.input_params:
+                        processor_inputs[name] = inputs_shared.get(name)
+                processor_outputs = unit.process(pipe, **processor_inputs)
+                inputs_nega.update(processor_outputs)
+            else:
+                inputs_nega.update(processor_outputs)
+        else:
+            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
+            processor_outputs = unit.process(pipe, **processor_inputs)
+            inputs_shared.update(processor_outputs)
+        return inputs_shared, inputs_posi, inputs_nega
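
The splitting logic above treats each `PipelineUnit` as a node and adds an edge whenever a unit reads a parameter last written by an earlier unit. A minimal self-contained sketch of the edge-building rule (toy stand-ins, not the real `PipelineUnit` class):

```python
# Toy illustration of BasePipeline.build_edges: an edge (i, j) means
# unit j consumes a parameter that unit i was the last to produce.
class ToyUnit:
    def __init__(self, inputs, outputs):
        self.inputs, self.outputs = inputs, outputs
    def fetch_input_params(self):
        return self.inputs
    def fetch_output_params(self):
        return self.outputs

def build_edges(units):
    last_writer, edges = {}, []
    for unit_id, unit in enumerate(units):
        for param in unit.fetch_input_params():
            if param in last_writer:
                edges.append((last_writer[param], unit_id))
        for param in unit.fetch_output_params():
            last_writer[param] = unit_id
    return edges

units = [
    ToyUnit([], ["prompt_emb"]),                      # 0: text encoding
    ToyUnit([], ["latents"]),                         # 1: noise initialization
    ToyUnit(["prompt_emb", "latents"], ["latents"]),  # 2: denoising
]
print(build_edges(units))  # [(0, 2), (1, 2)]
```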
diffsynth/diffusion/flow_match.py ADDED
@@ -0,0 +1,236 @@
+import torch, math
+from typing_extensions import Literal
+
+
+class FlowMatchScheduler():
+
+    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
+        self.set_timesteps_fn = {
+            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
+            "Wan": FlowMatchScheduler.set_timesteps_wan,
+            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
+            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
+            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
+            "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
+            "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
+        }.get(template, FlowMatchScheduler.set_timesteps_flux)
+        self.num_train_timesteps = 1000
+
+    @staticmethod
+    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
+        sigma_min = 0.003 / 1.002
+        sigma_max = 1.0
+        shift = 3 if shift is None else shift
+        num_train_timesteps = 1000
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
+        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    @staticmethod
+    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
+        sigma_min = 0.0
+        sigma_max = 1.0
+        shift = 5 if shift is None else shift
+        num_train_timesteps = 1000
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    @staticmethod
+    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
+        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+        b = base_shift - m * base_seq_len
+        mu = image_seq_len * m + b
+        return mu
+
+    @staticmethod
+    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
+        sigma_min = 0.0
+        sigma_max = 1.0
+        num_train_timesteps = 1000
+        shift_terminal = 0.02
+        # Sigmas
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+        # Mu
+        if exponential_shift_mu is not None:
+            mu = exponential_shift_mu
+        elif dynamic_shift_len is not None:
+            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
+        else:
+            mu = 0.8
+        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
+        # Shift terminal
+        one_minus_z = 1 - sigmas
+        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
+        sigmas = 1 - (one_minus_z / scale_factor)
+        # Timesteps
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    @staticmethod
+    def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
+        sigma_min = 0.0
+        sigma_max = 1.0
+        num_train_timesteps = 1000
+        base_shift = math.log(3)
+        max_shift = math.log(3)
+        # Sigmas
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+        # Mu
+        if exponential_shift_mu is not None:
+            mu = exponential_shift_mu
+        elif dynamic_shift_len is not None:
+            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
+        else:
+            mu = 0.8
+        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
+        # Timesteps
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    @staticmethod
+    def compute_empirical_mu(image_seq_len, num_steps):
+        a1, b1 = 8.73809524e-05, 1.89833333
+        a2, b2 = 0.00016927, 0.45666666
+
+        if image_seq_len > 4300:
+            mu = a2 * image_seq_len + b2
+            return float(mu)
+
+        m_200 = a2 * image_seq_len + b2
+        m_10 = a1 * image_seq_len + b1
+
+        a = (m_200 - m_10) / 190.0
+        b = m_200 - 200.0 * a
+        mu = a * num_steps + b
+
+        return float(mu)
+
+    @staticmethod
+    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
+        sigma_min = 1 / num_inference_steps
+        sigma_max = 1.0
+        num_train_timesteps = 1000
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
+        if dynamic_shift_len is None:
+            # If you ask me why I set mu=0.8,
+            # I can only say that it yields better training results.
+            mu = 0.8
+        else:
+            mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
+        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    @staticmethod
+    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
+        sigma_min = 0.0
+        sigma_max = 1.0
+        shift = 3 if shift is None else shift
+        num_train_timesteps = 1000
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        timesteps = sigmas * num_train_timesteps
+        if target_timesteps is not None:
+            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
+            for timestep in target_timesteps:
+                timestep_id = torch.argmin((timesteps - timestep).abs())
+                timesteps[timestep_id] = timestep
+        return sigmas, timesteps
+
+    @staticmethod
+    def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
+        num_train_timesteps = 1000
+        if special_case == "stage2":
+            sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
+        elif special_case == "ditilled_stage1":
+            sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
+        else:
+            dynamic_shift_len = dynamic_shift_len or 4096
+            sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
+                image_seq_len=dynamic_shift_len,
+                base_seq_len=1024,
+                max_seq_len=4096,
+                base_shift=0.95,
+                max_shift=2.05,
+            )
+            sigma_min = 0.0
+            sigma_max = 1.0
+            sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+            sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+            sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
+            # Shift terminal
+            one_minus_z = 1.0 - sigmas
+            scale_factor = one_minus_z[-1] / (1 - terminal)
+            sigmas = 1.0 - (one_minus_z / scale_factor)
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
+    def set_training_weight(self):
+        steps = 1000
+        x = self.timesteps
+        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
+        y_shifted = y - y.min()
+        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
+        if len(self.timesteps) != 1000:
+            # This is an empirical formula.
+            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
+            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
+        self.linear_timesteps_weights = bsmntw_weighing
+
+    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
+        self.sigmas, self.timesteps = self.set_timesteps_fn(
+            num_inference_steps=num_inference_steps,
+            denoising_strength=denoising_strength,
+            **kwargs,
+        )
+        if training:
+            self.set_training_weight()
+            self.training = True
+        else:
+            self.training = False
+
+    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        if to_final or timestep_id + 1 >= len(self.timesteps):
+            sigma_ = 0
+        else:
+            sigma_ = self.sigmas[timestep_id + 1]
+        prev_sample = sample + model_output * (sigma_ - sigma)
+        return prev_sample
+
+    def return_to_timestep(self, timestep, sample, sample_stablized):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        model_output = (sample - sample_stablized) / sigma
+        return model_output
+
+    def add_noise(self, original_samples, noise, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        sample = (1 - sigma) * original_samples + sigma * noise
+        return sample
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
+
+    def training_weight(self, timestep):
+        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+        weights = self.linear_timesteps_weights[timestep_id]
+        return weights
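
A minimal usage sketch of `FlowMatchScheduler` (the random tensor stands in for a real DiT's prediction; the shape is illustrative):

```python
import torch
from diffsynth.diffusion.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(template="FLUX.1")
scheduler.set_timesteps(num_inference_steps=20, shift=3)

latents = torch.randn(1, 16, 64, 64)  # stand-in for VAE latents
for timestep in scheduler.timesteps:
    model_output = torch.randn_like(latents)  # a real pipeline queries the DiT here
    latents = scheduler.step(model_output, timestep, latents)
# After the loop, latents corresponds to sigma = sigma_min (fully denoised).
```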
diffsynth/diffusion/logger.py ADDED
@@ -0,0 +1,43 @@
+import os, torch
+from accelerate import Accelerator
+
+
+class ModelLogger:
+    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x, resume_step=0):
+        self.output_path = output_path
+        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
+        self.state_dict_converter = state_dict_converter
+        self.num_steps = resume_step
+
+
+    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
+        self.num_steps += 1
+        if save_steps is not None and self.num_steps % save_steps == 0:
+            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
+
+
+    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
+        accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
+        if accelerator.is_main_process:
+            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
+            state_dict = self.state_dict_converter(state_dict)
+            os.makedirs(self.output_path, exist_ok=True)
+            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
+            accelerator.save(state_dict, path, safe_serialization=True)
+
+
+    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
+        if save_steps is not None and self.num_steps % save_steps != 0:
+            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
+
+
+    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
+        accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
+        if accelerator.is_main_process:
+            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
+            state_dict = self.state_dict_converter(state_dict)
+            os.makedirs(self.output_path, exist_ok=True)
+            path = os.path.join(self.output_path, file_name)
+            accelerator.save(state_dict, path, safe_serialization=True)
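
How the hooks above are meant to be driven from a training loop (a sketch; the model must implement `export_trainable_state_dict`, as `DiffusionTrainingModule` does):

```python
from accelerate import Accelerator
from diffsynth.diffusion.logger import ModelLogger

accelerator = Accelerator()
model_logger = ModelLogger(output_path="./models", remove_prefix_in_ckpt="pipe.dit.")

# Inside the loop, once per optimizer step; saves every `save_steps` steps:
#     model_logger.on_step_end(accelerator, model, save_steps=500)
# After the loop, flush a final checkpoint if the last interval was partial:
#     model_logger.on_training_end(accelerator, model, save_steps=500)
```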
diffsynth/diffusion/loss.py ADDED
@@ -0,0 +1,158 @@
+from .base_pipeline import BasePipeline
+import torch
+
+
+def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
+    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
+    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
+
+    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
+    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
+
+    noise = torch.randn_like(inputs["input_latents"])
+    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
+
+    if "first_frame_latents" in inputs:
+        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
+
+    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
+
+    if "first_frame_latents" in inputs:
+        noise_pred = noise_pred[:, :, 1:]
+        training_target = training_target[:, :, 1:]
+
+    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+    loss = loss * pipe.scheduler.training_weight(timestep)
+    return loss
+
+
+def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
+    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
+    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
+
+    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
+    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
+
+    # video
+    noise = torch.randn_like(inputs["input_latents"])
+    inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
+
+    # audio
+    if inputs.get("audio_input_latents") is not None:
+        audio_noise = torch.randn_like(inputs["audio_input_latents"])
+        inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
+        training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
+
+    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+    noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
+
+    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+    loss = loss * pipe.scheduler.training_weight(timestep)
+    if inputs.get("audio_input_latents") is not None:
+        loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
+        loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
+        loss = loss + loss_audio
+    return loss
+
+
+def DirectDistillLoss(pipe: BasePipeline, **inputs):
+    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
+    pipe.scheduler.training = True
+    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
+        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
+    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
+    return loss
+
+
+class TrajectoryImitationLoss(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.initialized = False
+
+    def initialize(self, device):
+        import lpips  # TODO: remove it
+        self.loss_fn = lpips.LPIPS(net='alex').to(device)
+        self.initialized = True
+
+    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        trajectory = [inputs_shared["latents"].clone()]
+
+        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
+
+            trajectory.append(inputs_shared["latents"].clone())
+        return pipe.scheduler.timesteps, trajectory
+
+    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        loss = 0
+        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
+            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
+
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+
+            sigma = pipe.scheduler.sigmas[progress_id]
+            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
+            if progress_id + 1 >= len(pipe.scheduler.timesteps):
+                latents_ = trajectory_teacher[-1]
+            else:
+                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
+                latents_ = trajectory_teacher[progress_id_teacher]
+
+            denom = sigma_ - sigma
+            denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
+            target = (latents_ - inputs_shared["latents"]) / denom
+            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
+        return loss
+
+    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        inputs_shared["latents"] = trajectory_teacher[0]
+        pipe.scheduler.set_timesteps(num_inference_steps)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
+
+        image_pred = pipe.vae_decoder(inputs_shared["latents"])
+        image_real = pipe.vae_decoder(trajectory_teacher[-1])
+        loss = self.loss_fn(image_pred.float(), image_real.float())
+        return loss
+
+    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
+        if not self.initialized:
+            self.initialize(pipe.device)
+        with torch.no_grad():
+            pipe.scheduler.set_timesteps(8)
+            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
+            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
+        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
+        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
+        loss = loss_1 + loss_2
+        return loss
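
The SFT losses above rely on the basic flow-matching identity: with `x_t = (1 - sigma) * x0 + sigma * noise` (as in `scheduler.add_noise`) and target `v = noise - x0` (as in `scheduler.training_target`), one exact Euler step from `sigma` to `0` recovers `x0`. A standalone numeric check:

```python
import torch

x0, noise, sigma = torch.randn(4), torch.randn(4), 0.7

x_t = (1 - sigma) * x0 + sigma * noise  # scheduler.add_noise
v = noise - x0                          # scheduler.training_target
x_rec = x_t + v * (0 - sigma)           # scheduler.step with sigma_ = 0
assert torch.allclose(x_rec, x0, atol=1e-6)
```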
diffsynth/diffusion/parsers.py ADDED
@@ -0,0 +1,71 @@
+import argparse
+
+
+def add_dataset_base_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
+    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
+    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
+    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
+    return parser
+
+def add_image_size_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
+    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
+    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
+    return parser
+
+def add_video_size_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--height", type=int, default=None, help="Height of video frames. Leave `height` and `width` empty to enable dynamic resolution.")
+    parser.add_argument("--width", type=int, default=None, help="Width of video frames. Leave `height` and `width` empty to enable dynamic resolution.")
+    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
+    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
+    return parser
+
+def add_model_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
+    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
+    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
+    parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
+    parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in split training.")
+    return parser
+
+def add_training_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
+    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
+    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
+    return parser
+
+def add_output_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
+    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Prefix to remove from parameter names in the saved checkpoint.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint saving interval in steps. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--resume_step", type=int, default=0, help="Starting step count when resuming. ModelLogger.num_steps initializes here; training stops when num_steps reaches num_epochs * steps_per_epoch.")
+    return parser
+
+def add_lora_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
+    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
+    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
+    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
+    parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused into the base model.")
+    parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused into.")
+    return parser
+
+def add_gradient_config(parser: argparse.ArgumentParser):
+    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
+    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
+    return parser
+
+def add_general_config(parser: argparse.ArgumentParser):
+    parser = add_dataset_base_config(parser)
+    parser = add_model_config(parser)
+    parser = add_training_config(parser)
+    parser = add_output_config(parser)
+    parser = add_lora_config(parser)
+    parser = add_gradient_config(parser)
+    return parser
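
A sketch of composing these helpers into a training CLI (argument values are illustrative):

```python
import argparse
from diffsynth.diffusion.parsers import add_general_config

parser = add_general_config(argparse.ArgumentParser(description="Training script"))
args = parser.parse_args([
    "--dataset_base_path", "./data",
    "--learning_rate", "1e-4",
    "--lora_base_model", "dit",
])
print(args.lora_rank)  # 32 (the default)
```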
diffsynth/diffusion/runner.py ADDED
@@ -0,0 +1,135 @@
+import os, torch
+from tqdm import tqdm
+from accelerate import Accelerator
+from .training_module import DiffusionTrainingModule
+from .logger import ModelLogger
+
+
+def launch_training_task(
+    accelerator: Accelerator,
+    dataset: torch.utils.data.Dataset,
+    model: DiffusionTrainingModule,
+    model_logger: ModelLogger,
+    learning_rate: float = 1e-5,
+    weight_decay: float = 1e-2,
+    num_workers: int = 1,
+    save_steps: int = None,
+    num_epochs: int = 1,
+    args=None,
+):
+    if args is not None:
+        learning_rate = args.learning_rate
+        weight_decay = args.weight_decay
+        num_workers = args.dataset_num_workers
+        save_steps = args.save_steps
+        num_epochs = args.num_epochs
+
+    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    model.to(device=accelerator.device)
+    # Exclude the VAE from DeepSpeed ZeRO-3 wrapping to avoid compatibility issues:
+    # store it outside the module tree so DeepSpeed doesn't touch it.
+    vae_module = getattr(model.pipe, 'vae', None)
+    if vae_module is not None:
+        del model.pipe._modules['vae']
+    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+    if vae_module is not None:
+        vae_module.to(accelerator.device)
+        # Store the VAE as a non-module attribute so pipeline code can still use pipe.vae.
+        pipe = model.module.pipe if hasattr(model, 'module') else model.pipe
+        # Use object.__setattr__ to bypass nn.Module.__setattr__, which would register it as a submodule.
+        object.__setattr__(pipe, 'vae', vae_module)
+    initialize_deepspeed_gradient_checkpointing(accelerator)
+    # Training log file
+    log_path = os.path.join(model_logger.output_path, "training_log.txt")
+    if accelerator.is_main_process:
+        os.makedirs(model_logger.output_path, exist_ok=True)
+        log_file = open(log_path, "a")
+        log_file.write(f"Training started. Epochs: {num_epochs}, LR: {learning_rate}, Steps/epoch: {len(dataloader)}\n")
+        log_file.flush()
+    else:
+        log_file = None
+
+    total_target = num_epochs * len(dataloader)
+    reached_target = False
+    for epoch_id in range(num_epochs):
+        if reached_target:
+            break
+        progress = tqdm(
+            total=total_target,
+            initial=model_logger.num_steps,
+            desc=f"Epoch {epoch_id+1}/{num_epochs}",
+        )
+        for step_id, data in enumerate(dataloader):
+            if model_logger.num_steps >= total_target:
+                reached_target = True
+                break
+            with accelerator.accumulate(model):
+                optimizer.zero_grad()
+                if dataset.load_from_cache:
+                    loss = model({}, inputs=data)
+                else:
+                    loss = model(data)
+                accelerator.backward(loss)
+                optimizer.step()
+                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
+                scheduler.step()
+
+            # Log loss
+            loss_val = loss.item()
+            progress.update(1)
+            progress.set_postfix(loss=f"{loss_val:.4f}")
+            if accelerator.is_main_process and log_file is not None and (model_logger.num_steps % 10 == 0 or model_logger.num_steps <= 5):
+                log_file.write(f"epoch={epoch_id+1} step={model_logger.num_steps} loss={loss_val:.6f}\n")
+                log_file.flush()
+        progress.close()
+        if save_steps is None:
+            model_logger.on_epoch_end(accelerator, model, epoch_id)
+            if accelerator.is_main_process and log_file is not None:
+                log_file.write(f"Epoch {epoch_id+1} completed. Checkpoint saved.\n")
+                log_file.flush()
+    model_logger.on_training_end(accelerator, model, save_steps)
+    if log_file is not None:
+        log_file.close()
+
+
+def launch_data_process_task(
+    accelerator: Accelerator,
+    dataset: torch.utils.data.Dataset,
+    model: DiffusionTrainingModule,
+    model_logger: ModelLogger,
+    num_workers: int = 8,
+    args=None,
+):
+    if args is not None:
+        num_workers = args.dataset_num_workers
+
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
+    model.to(device=accelerator.device)
+    model, dataloader = accelerator.prepare(model, dataloader)
+
+    for data_id, data in enumerate(tqdm(dataloader)):
+        with accelerator.accumulate(model):
+            with torch.no_grad():
+                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
+                os.makedirs(folder, exist_ok=True)
+                save_path = os.path.join(folder, f"{data_id}.pth")
+                data = model(data)
+                torch.save(data, save_path)
+
+
+def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
+    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
+        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
+        if "activation_checkpointing" in ds_config:
+            import deepspeed
+            act_config = ds_config["activation_checkpointing"]
+            deepspeed.checkpointing.configure(
+                mpu_=None,
+                partition_activations=act_config.get("partition_activations", False),
+                checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
+                contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
+            )
+        else:
+            print("No activation_checkpointing entry found in the DeepSpeed config; skipping DeepSpeed gradient checkpointing initialization.")
diffsynth/diffusion/training_module.py ADDED
@@ -0,0 +1,302 @@
+import torch, json, os, inspect
+from ..core import ModelConfig, load_state_dict
+from ..utils.controlnet import ControlNetInput
+from .base_pipeline import PipelineUnit
+from peft import LoraConfig, inject_adapter_in_model
+
+
+class GeneralUnit_RemoveCache(PipelineUnit):
+    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
+        super().__init__(take_over=True)
+        self.required_params = required_params
+        self.force_remove_params_shared = force_remove_params_shared
+        self.force_remove_params_posi = force_remove_params_posi
+        self.force_remove_params_nega = force_remove_params_nega
+
+    def process_params(self, inputs, required_params, force_remove_params):
+        inputs_ = {}
+        for name, param in inputs.items():
+            if name in required_params and name not in force_remove_params:
+                inputs_[name] = param
+        return inputs_
+
+    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
+        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
+        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
+        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class DiffusionTrainingModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+    def to(self, *args, **kwargs):
+        for name, model in self.named_children():
+            model.to(*args, **kwargs)
+        return self
+
+
+    def trainable_modules(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
+        return trainable_modules
+
+
+    def trainable_param_names(self):
+        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
+        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
+        return trainable_param_names
+
+
+    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
+        if lora_alpha is None:
+            lora_alpha = lora_rank
+        if isinstance(target_modules, list) and len(target_modules) == 1:
+            target_modules = target_modules[0]
+        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
+        model = inject_adapter_in_model(lora_config, model)
+        if upcast_dtype is not None:
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.data = param.to(upcast_dtype)
+        return model
+
+
+    def mapping_lora_state_dict(self, state_dict):
+        new_state_dict = {}
+        for key, value in state_dict.items():
+            if "lora_A.weight" in key or "lora_B.weight" in key:
+                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
+                new_state_dict[new_key] = value
+            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
+                new_state_dict[key] = value
+        return new_state_dict
+
+
+    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
+        trainable_param_names = self.trainable_param_names()
+        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
+        if remove_prefix is not None:
+            state_dict_ = {}
+            for name, param in state_dict.items():
+                if name.startswith(remove_prefix):
+                    name = name[len(remove_prefix):]
+                state_dict_[name] = param
+            state_dict = state_dict_
+        return state_dict
+
+
+    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
+        if data is None:
+            return data
+        elif isinstance(data, torch.Tensor):
+            data = data.to(device)
+            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
+                data = data.to(torch_float_dtype)
+            return data
+        elif isinstance(data, tuple):
+            data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
+            return data
+        elif isinstance(data, list):
+            data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
+            return data
+        elif isinstance(data, dict):
+            data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
+            return data
+        else:
+            return data
+
+    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
+        if fp8:
+            return {
+                "offload_dtype": torch.float8_e4m3fn,
+                "offload_device": device,
+                "onload_dtype": torch.float8_e4m3fn,
+                "onload_device": device,
+                "preparing_dtype": torch.float8_e4m3fn,
+                "preparing_device": device,
+                "computation_dtype": torch.bfloat16,
+                "computation_device": device,
+            }
+        elif offload:
+            return {
+                "offload_dtype": "disk",
+                "offload_device": "disk",
+                "onload_dtype": "disk",
+                "onload_device": "disk",
+                "preparing_dtype": torch.bfloat16,
+                "preparing_device": device,
+                "computation_dtype": torch.bfloat16,
+                "computation_device": device,
+                "clear_parameters": True,
+            }
+        else:
+            return {}
+
+    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
+        fp8_models = [] if fp8_models is None else fp8_models.split(",")
+        offload_models = [] if offload_models is None else offload_models.split(",")
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            for path in model_paths:
+                vram_config = self.parse_vram_config(
+                    fp8=path in fp8_models,
+                    offload=path in offload_models,
+                    device=device
+                )
+                model_configs.append(ModelConfig(path=path, **vram_config))
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            for model_id_with_origin_path in model_id_with_origin_paths:
+                vram_config = self.parse_vram_config(
+                    fp8=model_id_with_origin_path in fp8_models,
+                    offload=model_id_with_origin_path in offload_models,
+                    device=device
+                )
+                config = self.parse_path_or_model_id(model_id_with_origin_path)
+                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
+        return model_configs
+
+
+    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
+        if model_id_with_origin_path is None:
+            return default_value
+        elif os.path.exists(model_id_with_origin_path):
+            return ModelConfig(path=model_id_with_origin_path)
+        else:
+            if ":" not in model_id_with_origin_path:
+                raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
+            split_id = model_id_with_origin_path.rfind(":")
+            model_id = model_id_with_origin_path[:split_id]
+            origin_file_pattern = model_id_with_origin_path[split_id + 1:]
+            return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
+
+
+    def auto_detect_lora_target_modules(
+        self,
+        model: torch.nn.Module,
+        search_for_linear=False,
+        linear_detector=lambda x: min(x.weight.shape) >= 512,
+        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
+        name_prefix="",
+    ):
+        lora_target_modules = []
+        if search_for_linear:
+            for name, module in model.named_modules():
+                module_name = name_prefix + ["", "."][name_prefix != ""] + name
+                if isinstance(module, torch.nn.Linear) and linear_detector(module):
+                    lora_target_modules.append(module_name)
+        else:
+            for name, module in model.named_children():
+                module_name = name_prefix + ["", "."][name_prefix != ""] + name
+                lora_target_modules += self.auto_detect_lora_target_modules(
+                    module,
+                    search_for_linear=block_list_detector(module),
+                    linear_detector=linear_detector,
+                    block_list_detector=block_list_detector,
+                    name_prefix=module_name,
+                )
+        return lora_target_modules
+
+
+    def parse_lora_target_modules(self, model, lora_target_modules):
+        if lora_target_modules == "":
+            print("No LoRA target modules specified. The framework will automatically search for them.")
+            lora_target_modules = self.auto_detect_lora_target_modules(model)
+            print(f"LoRA will be patched at {lora_target_modules}.")
+        else:
+            lora_target_modules = lora_target_modules.split(",")
+        return lora_target_modules
+
+
+    def switch_pipe_to_training_mode(
+        self,
+        pipe,
+        trainable_models=None,
+        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+        preset_lora_path=None, preset_lora_model=None,
+        task="sft",
+    ):
+        # Scheduler
+        pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Preset LoRA
+        if preset_lora_path is not None:
+            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
+
+        # FP8
+        # FP8 relies on a model-specific memory management scheme.
+        # It is delegated to the subclass.
+
+        # Add LoRA to the base models
+        if lora_base_model is not None and not task.endswith(":data_process"):
+            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
+                print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
+                return
+            model = self.add_lora_to_model(
+                getattr(pipe, lora_base_model),
+                target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
+                lora_rank=lora_rank,
+                upcast_dtype=pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(pipe, lora_base_model, model)
+
+
+    def split_pipeline_units(
+        self, task, pipe,
+        trainable_models=None, lora_base_model=None,
+        # TODO: set `remove_unnecessary_params` to `True` by default
+        remove_unnecessary_params=False,
+        # TODO: move `loss_required_params` to `loss.py`
+        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
+        force_remove_params_shared=tuple(),
+        force_remove_params_posi=tuple(),
+        force_remove_params_nega=tuple(),
+    ):
+        models_require_backward = []
+        if trainable_models is not None:
+            models_require_backward += trainable_models.split(",")
+        if lora_base_model is not None:
+            models_require_backward += [lora_base_model]
+        if task.endswith(":data_process"):
+            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
+            if remove_unnecessary_params:
+                required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]
+                for unit in other_units:
+                    required_params.extend(unit.fetch_input_params())
+                required_params = sorted(list(set(required_params)))
+                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
+        elif task.endswith(":train"):
+            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
+        return pipe
+
+    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
+        controlnet_keys_map = (
+            ("blockwise_controlnet_", "blockwise_controlnet_inputs"),
+            ("controlnet_", "controlnet_inputs"),
+        )
+        controlnet_inputs = {}
+        for extra_input in extra_inputs:
+            for prefix, name in controlnet_keys_map:
+                if extra_input.startswith(prefix):
+                    if name not in controlnet_inputs:
+                        controlnet_inputs[name] = {}
+                    controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
+                    break
+            else:
+                inputs_shared[extra_input] = data[extra_input]
+        for name, params in controlnet_inputs.items():
+            inputs_shared[name] = [ControlNetInput(**params)]
+        return inputs_shared
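
The auto-detection above only descends into `nn.ModuleList`s of repeated blocks (`block_list_detector`) and collects `nn.Linear` layers whose smaller weight dimension is at least 512 (`linear_detector`). A toy demonstration (the model below is hypothetical):

```python
import torch
from diffsynth.diffusion.training_module import DiffusionTrainingModule

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(1024, 1024)
        self.tiny = torch.nn.Linear(8, 8)  # filtered out: min dim < 512

class ToyDiT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([Block() for _ in range(2)])

print(DiffusionTrainingModule().auto_detect_lora_target_modules(ToyDiT()))
# ['blocks.0.q', 'blocks.1.q']
```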
diffsynth/models/anima_dit.py ADDED
@@ -0,0 +1,1307 @@
1
+ # original code from: comfy/ldm/cosmos/predict2.py
2
+
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+ import logging
8
+ from typing import Callable, Optional, Tuple, List
9
+ import math
10
+ from torchvision import transforms
11
+ from ..core.attention import attention_forward
12
+ from ..core.gradient import gradient_checkpoint_forward
13
+
14
+
15
+ class VideoPositionEmb(nn.Module):
16
+ def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
17
+ """
18
+ It delegates the embedding generation to generate_embeddings function.
19
+ """
20
+ B_T_H_W_C = x_B_T_H_W_C.shape
21
+ embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
22
+
23
+ return embeddings
24
+
25
+ def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
26
+ raise NotImplementedError
27
+
28
+
29
+ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
30
+ """
31
+ Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
32
+
33
+ Args:
34
+ x (torch.Tensor): The input tensor to normalize.
35
+ dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
36
+ eps (float, optional): A small constant to ensure numerical stability during division.
37
+
38
+ Returns:
39
+ torch.Tensor: The normalized tensor.
40
+ """
41
+ if dim is None:
42
+ dim = list(range(1, x.ndim))
43
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
44
+ norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
45
+ return x / norm.to(x.dtype)
46
+
47
+
48
+ class LearnablePosEmbAxis(VideoPositionEmb):
49
+ def __init__(
50
+ self,
51
+ *, # enforce keyword arguments
52
+ interpolation: str,
53
+ model_channels: int,
54
+ len_h: int,
55
+ len_w: int,
56
+ len_t: int,
57
+ device=None,
58
+ dtype=None,
59
+ **kwargs,
60
+ ):
61
+ """
62
+ Args:
63
+ interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
64
+ """
65
+ del kwargs # unused
66
+ super().__init__()
67
+ self.interpolation = interpolation
68
+ assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
69
+
70
+ self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
71
+ self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
72
+ self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
73
+
74
+ def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
75
+ B, T, H, W, _ = B_T_H_W_C
76
+ if self.interpolation == "crop":
77
+ emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
78
+ emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
79
+ emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
80
+ emb = (
81
+ repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
82
+ + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
83
+ + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
84
+ )
85
+ assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
86
+ else:
87
+ raise ValueError(f"Unknown interpolation method {self.interpolation}")
88
+
89
+ return normalize(emb, dim=-1, eps=1e-6)
90
+
91
+
92
+ class VideoRopePosition3DEmb(VideoPositionEmb):
93
+ def __init__(
94
+ self,
95
+ *, # enforce keyword arguments
96
+ head_dim: int,
97
+ len_h: int,
98
+ len_w: int,
99
+ len_t: int,
100
+ base_fps: int = 24,
101
+ h_extrapolation_ratio: float = 1.0,
102
+ w_extrapolation_ratio: float = 1.0,
103
+ t_extrapolation_ratio: float = 1.0,
104
+ enable_fps_modulation: bool = True,
105
+ device=None,
106
+ **kwargs, # used for compatibility with other positional embeddings; unused in this class
107
+ ):
108
+ del kwargs
109
+ super().__init__()
110
+ self.base_fps = base_fps
111
+ self.max_h = len_h
112
+ self.max_w = len_w
113
+ self.enable_fps_modulation = enable_fps_modulation
114
+
115
+ dim = head_dim
116
+ dim_h = dim // 6 * 2
117
+ dim_w = dim_h
118
+ dim_t = dim - 2 * dim_h
119
+ assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
120
+ self.register_buffer(
121
+ "dim_spatial_range",
122
+ torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
123
+ persistent=False,
124
+ )
125
+ self.register_buffer(
126
+ "dim_temporal_range",
127
+ torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
128
+ persistent=False,
129
+ )
130
+
131
+ self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
132
+ self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
133
+ self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
134
+
135
+ def generate_embeddings(
136
+ self,
137
+ B_T_H_W_C: torch.Size,
138
+ fps: Optional[torch.Tensor] = None,
139
+ h_ntk_factor: Optional[float] = None,
140
+ w_ntk_factor: Optional[float] = None,
141
+ t_ntk_factor: Optional[float] = None,
142
+ device=None,
143
+ dtype=None,
144
+ ):
145
+ """
146
+ Generate embeddings for the given input size.
147
+
148
+ Args:
149
+ B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
150
+ fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
151
+ h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
152
+ w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
153
+ t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
154
+
155
+ Returns:
156
+ torch.Tensor: Rotary embedding frequencies of shape (T*H*W, head_dim // 2, 2, 2), stored as per-frequency 2x2 rotation matrices.
157
+ """
158
+ h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
159
+ w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
160
+ t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
161
+
162
+ h_theta = 10000.0 * h_ntk_factor
163
+ w_theta = 10000.0 * w_ntk_factor
164
+ t_theta = 10000.0 * t_ntk_factor
165
+
166
+ h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
167
+ w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
168
+ temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
169
+
170
+ B, T, H, W, _ = B_T_H_W_C
171
+ seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
172
+ uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
173
+ assert (
174
+ uniform_fps or B == 1 or T == 1
175
+ ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
176
+ half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
177
+ half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
178
+
179
+ # apply sequence scaling in temporal dimension
180
+ if fps is None or self.enable_fps_modulation is False: # image case
181
+ half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
182
+ else:
183
+ half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
184
+
185
+ half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
186
+ half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
187
+ half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
188
+
189
+ em_T_H_W_D = torch.cat(
190
+ [
191
+ repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
192
+ repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
193
+ repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
194
+ ]
195
+ , dim=-2,
196
+ )
197
+
198
+ return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
199
+
200
+
201
+ def apply_rotary_pos_emb(
202
+ t: torch.Tensor,
203
+ freqs: torch.Tensor,
204
+ ) -> torch.Tensor:
205
+ t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
206
+ t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
207
+ t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
208
+ return t_out
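+
+
+ # Quick sanity sketch (shapes are assumptions): the (S, D//2, 2, 2) rotation matrices
+ # produced by VideoRopePosition3DEmb.generate_embeddings rotate coordinate pairs without
+ # changing vector norms.
+ def _demo_apply_rotary_pos_emb():
+     emb = VideoRopePosition3DEmb(head_dim=64, len_h=8, len_w=8, len_t=4)
+     freqs = emb.generate_embeddings(torch.Size([1, 4, 8, 8, 64]))  # (4*8*8, 32, 2, 2)
+     q = torch.randn(1, 4 * 8 * 8, 2, 64)  # (B, S, n_heads, head_dim)
+     q_rot = apply_rotary_pos_emb(q, freqs.unsqueeze(1).unsqueeze(0))
+     assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)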
209
+
210
+
211
+ # ---------------------- Feed Forward Network -----------------------
212
+ class GPT2FeedForward(nn.Module):
213
+ def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
214
+ super().__init__()
215
+ self.activation = nn.GELU()
216
+ self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
217
+ self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
218
+
219
+ self._layer_id = None
220
+ self._dim = d_model
221
+ self._hidden_dim = d_ff
222
+
223
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
224
+ x = self.layer1(x)
225
+
226
+ x = self.activation(x)
227
+ x = self.layer2(x)
228
+ return x
229
+
230
+
231
+ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
232
+ """Computes multi-head attention using PyTorch's native implementation.
233
+
234
+ This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
235
+ It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
236
+ attention, and rearranges the output back to the original format.
237
+
238
+ The input tensor names use the following dimension conventions:
239
+
240
+ - B: batch size
241
+ - S: sequence length
242
+ - H: number of attention heads
243
+ - D: head dimension
244
+
245
+ Args:
246
+ q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
247
+ k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
248
+ v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
249
+
250
+ Returns:
251
+ Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
252
+ """
253
+ in_q_shape = q_B_S_H_D.shape
254
+ in_k_shape = k_B_S_H_D.shape
255
+ q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
256
+ k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
257
+ v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
258
+ return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)")
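+
+
+ # Usage sketch (shapes are assumptions): heads are merged into the last output dim,
+ # matching the "b s (n d)" out_pattern documented for attention_forward.
+ def _demo_torch_attention_op():
+     q = torch.randn(2, 10, 4, 8)  # (B, S, H, D)
+     out = torch_attention_op(q, q, q)
+     assert out.shape == (2, 10, 32)  # (B, S, H * D)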
259
+
260
+
261
+ class Attention(nn.Module):
262
+ """
263
+ A flexible attention module supporting both self-attention and cross-attention mechanisms.
264
+
265
+ This module implements a multi-head attention layer that can operate in either self-attention
266
+ or cross-attention mode. The mode is determined by whether a context dimension is provided.
267
+ The implementation uses scaled dot-product attention and supports optional bias terms and
268
+ dropout regularization.
269
+
270
+ Args:
271
+ query_dim (int): The dimensionality of the query vectors.
272
+ context_dim (int, optional): The dimensionality of the context (key/value) vectors.
273
+ If None, the module operates in self-attention mode using query_dim. Default: None
274
+ n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
275
+ head_dim (int, optional): The dimension of each attention head. Default: 64
276
+ dropout (float, optional): Dropout probability applied to the output. Default: 0.0
277
+ device, dtype (optional): Device and dtype used to allocate the projection weights.
+ operations: Namespace providing the Linear and RMSNorm implementations (e.g. torch.nn).
279
+
280
+ Examples:
281
+ >>> # Self-attention with 512 dimensions and 8 heads
282
+ >>> self_attn = Attention(query_dim=512, operations=torch.nn)
283
+ >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
284
+ >>> out = self_attn(x) # (32, 16, 512)
285
+
286
+ >>> # Cross-attention
287
+ >>> cross_attn = Attention(query_dim=512, context_dim=256, operations=torch.nn)
288
+ >>> query = torch.randn(32, 16, 512)
289
+ >>> context = torch.randn(32, 8, 256)
290
+ >>> out = cross_attn(query, context) # (32, 16, 512)
291
+ """
292
+
293
+ def __init__(
294
+ self,
295
+ query_dim: int,
296
+ context_dim: Optional[int] = None,
297
+ n_heads: int = 8,
298
+ head_dim: int = 64,
299
+ dropout: float = 0.0,
300
+ device=None,
301
+ dtype=None,
302
+ operations=None,
303
+ ) -> None:
304
+ super().__init__()
305
+ logging.debug(
306
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
307
+ f"{n_heads} heads with a dimension of {head_dim}."
308
+ )
309
+ self.is_selfattn = context_dim is None # self attention
310
+
311
+ context_dim = query_dim if context_dim is None else context_dim
312
+ inner_dim = head_dim * n_heads
313
+
314
+ self.n_heads = n_heads
315
+ self.head_dim = head_dim
316
+ self.query_dim = query_dim
317
+ self.context_dim = context_dim
318
+
319
+ self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
320
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
321
+
322
+ self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
323
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
324
+
325
+ self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
326
+ self.v_norm = nn.Identity()
327
+
328
+ self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
329
+ self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
330
+
331
+ self.attn_op = torch_attention_op
332
+
333
+ self._query_dim = query_dim
334
+ self._context_dim = context_dim
335
+ self._inner_dim = inner_dim
336
+
337
+ def compute_qkv(
338
+ self,
339
+ x: torch.Tensor,
340
+ context: Optional[torch.Tensor] = None,
341
+ rope_emb: Optional[torch.Tensor] = None,
342
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
343
+ q = self.q_proj(x)
344
+ context = x if context is None else context
345
+ k = self.k_proj(context)
346
+ v = self.v_proj(context)
347
+ q, k, v = map(
348
+ lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
349
+ (q, k, v),
350
+ )
351
+
352
+ def apply_norm_and_rotary_pos_emb(
353
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
354
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
355
+ q = self.q_norm(q)
356
+ k = self.k_norm(k)
357
+ v = self.v_norm(v)
358
+ if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
359
+ q = apply_rotary_pos_emb(q, rope_emb)
360
+ k = apply_rotary_pos_emb(k, rope_emb)
361
+ return q, k, v
362
+
363
+ q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
364
+
365
+ return q, k, v
366
+
367
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
368
+ result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
369
+ return self.output_dropout(self.output_proj(result))
370
+
371
+ def forward(
372
+ self,
373
+ x: torch.Tensor,
374
+ context: Optional[torch.Tensor] = None,
375
+ rope_emb: Optional[torch.Tensor] = None,
376
+ transformer_options: Optional[dict] = {},
377
+ ) -> torch.Tensor:
378
+ """
379
+ Args:
380
+ x (Tensor): The query tensor of shape [B, Mq, K]
381
+ context (Optional[Tensor]): The context tensor of shape [B, Mk, K]; if None, x is used as the context (self-attention)
382
+ """
383
+ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
384
+ return self.compute_attention(q, k, v, transformer_options=transformer_options)
385
+
386
+
387
+ class Timesteps(nn.Module):
388
+ def __init__(self, num_channels: int):
389
+ super().__init__()
390
+ self.num_channels = num_channels
391
+
392
+ def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
393
+ assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
394
+ timesteps = timesteps_B_T.flatten().float()
395
+ half_dim = self.num_channels // 2
396
+ exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
397
+ exponent = exponent / half_dim  # no downscale-frequency shift
398
+
399
+ emb = torch.exp(exponent)
400
+ emb = timesteps[:, None].float() * emb[None, :]
401
+
402
+ sin_emb = torch.sin(emb)
403
+ cos_emb = torch.cos(emb)
404
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
405
+
406
+ return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
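+
+
+ # Usage sketch: per-(batch, frame) sinusoidal embeddings, cos half followed by sin half.
+ def _demo_timesteps():
+     te = Timesteps(num_channels=64)
+     t = torch.tensor([[0.0, 500.0], [999.0, 10.0]])  # (B, T)
+     assert te(t).shape == (2, 2, 64)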
407
+
408
+
409
+ class TimestepEmbedding(nn.Module):
410
+ def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
411
+ super().__init__()
412
+ logging.debug(
413
+ f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
414
+ )
415
+ self.in_dim = in_features
416
+ self.out_dim = out_features
417
+ self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
418
+ self.activation = nn.SiLU()
419
+ self.use_adaln_lora = use_adaln_lora
420
+ if use_adaln_lora:
421
+ self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
422
+ else:
423
+ self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
424
+
425
+ def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
426
+ emb = self.linear_1(sample)
427
+ emb = self.activation(emb)
428
+ emb = self.linear_2(emb)
429
+
430
+ if self.use_adaln_lora:
431
+ adaln_lora_B_T_3D = emb
432
+ emb_B_T_D = sample
433
+ else:
434
+ adaln_lora_B_T_3D = None
435
+ emb_B_T_D = emb
436
+
437
+ return emb_B_T_D, adaln_lora_B_T_3D
438
+
439
+
440
+ class PatchEmbed(nn.Module):
441
+ """
442
+ PatchEmbed is a module for embedding patches from an input tensor by flattening non-overlapping
+ spatio-temporal patches and projecting them with a linear layer. This module can process inputs with temporal (video) and spatial (image) dimensions,
444
+ making it suitable for video and image processing tasks. It supports dividing the input into patches
445
+ and embedding each patch into a vector of size `out_channels`.
446
+
447
+ Parameters:
448
+ - spatial_patch_size (int): The size of each spatial patch.
449
+ - temporal_patch_size (int): The size of each temporal patch.
450
+ - in_channels (int): Number of input channels. Default: 3.
451
+ - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
452
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ spatial_patch_size: int,
458
+ temporal_patch_size: int,
459
+ in_channels: int = 3,
460
+ out_channels: int = 768,
461
+ device=None, dtype=None, operations=None
462
+ ):
463
+ super().__init__()
464
+ self.spatial_patch_size = spatial_patch_size
465
+ self.temporal_patch_size = temporal_patch_size
466
+
467
+ self.proj = nn.Sequential(
468
+ Rearrange(
469
+ "b c (t r) (h m) (w n) -> b t h w (c r m n)",
470
+ r=temporal_patch_size,
471
+ m=spatial_patch_size,
472
+ n=spatial_patch_size,
473
+ ),
474
+ operations.Linear(
475
+ in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
476
+ ),
477
+ )
478
+ self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
479
+
480
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
481
+ """
482
+ Forward pass of the PatchEmbed module.
483
+
484
+ Parameters:
485
+ - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
486
+ B is the batch size,
487
+ C is the number of channels,
488
+ T is the temporal dimension,
489
+ H is the height, and
490
+ W is the width of the input.
491
+
492
+ Returns:
493
+ - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
494
+ """
495
+ assert x.dim() == 5
496
+ _, _, T, H, W = x.shape
497
+ assert (
498
+ H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
499
+ ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
500
+ assert T % self.temporal_patch_size == 0
501
+ x = self.proj(x)
502
+ return x
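+
+
+ # Usage sketch, assuming plain torch.nn as the operations namespace: each 1x2x2 patch of
+ # a 16-channel latent is flattened and projected to 32 channels.
+ def _demo_patch_embed():
+     pe = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1, in_channels=16,
+                     out_channels=32, operations=torch.nn)
+     x = torch.randn(1, 16, 4, 8, 8)  # (B, C, T, H, W)
+     assert pe(x).shape == (1, 4, 4, 4, 32)  # (B, T, H/2, W/2, out_channels)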
503
+
504
+
505
+ class FinalLayer(nn.Module):
506
+ """
507
+ The final layer of video DiT.
508
+ """
509
+
510
+ def __init__(
511
+ self,
512
+ hidden_size: int,
513
+ spatial_patch_size: int,
514
+ temporal_patch_size: int,
515
+ out_channels: int,
516
+ use_adaln_lora: bool = False,
517
+ adaln_lora_dim: int = 256,
518
+ device=None, dtype=None, operations=None
519
+ ):
520
+ super().__init__()
521
+ self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
522
+ self.linear = operations.Linear(
523
+ hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
524
+ )
525
+ self.hidden_size = hidden_size
526
+ self.n_adaln_chunks = 2
527
+ self.use_adaln_lora = use_adaln_lora
528
+ self.adaln_lora_dim = adaln_lora_dim
529
+ if use_adaln_lora:
530
+ self.adaln_modulation = nn.Sequential(
531
+ nn.SiLU(),
532
+ operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
533
+ operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
534
+ )
535
+ else:
536
+ self.adaln_modulation = nn.Sequential(
537
+ nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
538
+ )
539
+
540
+ def forward(
541
+ self,
542
+ x_B_T_H_W_D: torch.Tensor,
543
+ emb_B_T_D: torch.Tensor,
544
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
545
+ ):
546
+ if self.use_adaln_lora:
547
+ assert adaln_lora_B_T_3D is not None
548
+ shift_B_T_D, scale_B_T_D = (
549
+ self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
550
+ ).chunk(2, dim=-1)
551
+ else:
552
+ shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
553
+
554
+ shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
555
+ scale_B_T_D, "b t d -> b t 1 1 d"
556
+ )
557
+
558
+ def _fn(
559
+ _x_B_T_H_W_D: torch.Tensor,
560
+ _norm_layer: nn.Module,
561
+ _scale_B_T_1_1_D: torch.Tensor,
562
+ _shift_B_T_1_1_D: torch.Tensor,
563
+ ) -> torch.Tensor:
564
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
565
+
566
+ x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
567
+ x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
568
+ return x_B_T_H_W_O
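+
+
+ # Usage sketch (illustrative sizes): the AdaLN shift/scale come from emb_B_T_D, and the
+ # output dim equals spatial_patch_size**2 * temporal_patch_size * out_channels.
+ def _demo_final_layer():
+     fl = FinalLayer(hidden_size=32, spatial_patch_size=2, temporal_patch_size=1,
+                     out_channels=16, operations=torch.nn)
+     x = torch.randn(1, 4, 4, 4, 32)
+     emb = torch.randn(1, 4, 32)
+     assert fl(x, emb).shape == (1, 4, 4, 4, 64)  # 2 * 2 * 1 * 16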
569
+
570
+
571
+ class Block(nn.Module):
572
+ """
573
+ A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
574
+ Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
575
+
576
+ Parameters:
577
+ x_dim (int): Dimension of input features
578
+ context_dim (int): Dimension of context features for cross-attention
579
+ num_heads (int): Number of attention heads
580
+ mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
581
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
582
+ adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
583
+
584
+ The block applies the following sequence:
585
+ 1. Self-attention with AdaLN modulation
586
+ 2. Cross-attention with AdaLN modulation
587
+ 3. MLP with AdaLN modulation
588
+
589
+ Each component uses skip connections and layer normalization.
590
+ """
591
+
592
+ def __init__(
593
+ self,
594
+ x_dim: int,
595
+ context_dim: int,
596
+ num_heads: int,
597
+ mlp_ratio: float = 4.0,
598
+ use_adaln_lora: bool = False,
599
+ adaln_lora_dim: int = 256,
600
+ device=None,
601
+ dtype=None,
602
+ operations=None,
603
+ ):
604
+ super().__init__()
605
+ self.x_dim = x_dim
606
+ self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
607
+ self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
608
+
609
+ self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
610
+ self.cross_attn = Attention(
611
+ x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
612
+ )
613
+
614
+ self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
615
+ self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
616
+
617
+ self.use_adaln_lora = use_adaln_lora
618
+ if self.use_adaln_lora:
619
+ self.adaln_modulation_self_attn = nn.Sequential(
620
+ nn.SiLU(),
621
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
622
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
623
+ )
624
+ self.adaln_modulation_cross_attn = nn.Sequential(
625
+ nn.SiLU(),
626
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
627
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
628
+ )
629
+ self.adaln_modulation_mlp = nn.Sequential(
630
+ nn.SiLU(),
631
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
632
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
633
+ )
634
+ else:
635
+ self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
636
+ self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
637
+ self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
638
+
639
+ def forward(
640
+ self,
641
+ x_B_T_H_W_D: torch.Tensor,
642
+ emb_B_T_D: torch.Tensor,
643
+ crossattn_emb: torch.Tensor,
644
+ rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
645
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
646
+ extra_per_block_pos_emb: Optional[torch.Tensor] = None,
647
+ transformer_options: Optional[dict] = {},
648
+ ) -> torch.Tensor:
649
+ residual_dtype = x_B_T_H_W_D.dtype
650
+ compute_dtype = emb_B_T_D.dtype
651
+ if extra_per_block_pos_emb is not None:
652
+ x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
653
+
654
+ if self.use_adaln_lora:
655
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
656
+ self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
657
+ ).chunk(3, dim=-1)
658
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
659
+ self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
660
+ ).chunk(3, dim=-1)
661
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
662
+ self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
663
+ ).chunk(3, dim=-1)
664
+ else:
665
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
666
+ emb_B_T_D
667
+ ).chunk(3, dim=-1)
668
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
669
+ emb_B_T_D
670
+ ).chunk(3, dim=-1)
671
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
672
+
673
+ # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
674
+ shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
675
+ scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
676
+ gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
677
+
678
+ shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
679
+ scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
680
+ gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
681
+
682
+ shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
683
+ scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
684
+ gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
685
+
686
+ B, T, H, W, D = x_B_T_H_W_D.shape
687
+
688
+ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
689
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
690
+
691
+ normalized_x_B_T_H_W_D = _fn(
692
+ x_B_T_H_W_D,
693
+ self.layer_norm_self_attn,
694
+ scale_self_attn_B_T_1_1_D,
695
+ shift_self_attn_B_T_1_1_D,
696
+ )
697
+ result_B_T_H_W_D = rearrange(
698
+ self.self_attn(
699
+ # normalized_x_B_T_HW_D,
700
+ rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
701
+ None,
702
+ rope_emb=rope_emb_L_1_1_D,
703
+ transformer_options=transformer_options,
704
+ ),
705
+ "b (t h w) d -> b t h w d",
706
+ t=T,
707
+ h=H,
708
+ w=W,
709
+ )
710
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
711
+
712
+ def _x_fn(
713
+ _x_B_T_H_W_D: torch.Tensor,
714
+ layer_norm_cross_attn: Callable,
715
+ _scale_cross_attn_B_T_1_1_D: torch.Tensor,
716
+ _shift_cross_attn_B_T_1_1_D: torch.Tensor,
717
+ transformer_options: Optional[dict] = {},
718
+ ) -> torch.Tensor:
719
+ _normalized_x_B_T_H_W_D = _fn(
720
+ _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
721
+ )
722
+ _result_B_T_H_W_D = rearrange(
723
+ self.cross_attn(
724
+ rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
725
+ crossattn_emb,
726
+ rope_emb=rope_emb_L_1_1_D,
727
+ transformer_options=transformer_options,
728
+ ),
729
+ "b (t h w) d -> b t h w d",
730
+ t=T,
731
+ h=H,
732
+ w=W,
733
+ )
734
+ return _result_B_T_H_W_D
735
+
736
+ result_B_T_H_W_D = _x_fn(
737
+ x_B_T_H_W_D,
738
+ self.layer_norm_cross_attn,
739
+ scale_cross_attn_B_T_1_1_D,
740
+ shift_cross_attn_B_T_1_1_D,
741
+ transformer_options=transformer_options,
742
+ )
743
+ x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
744
+
745
+ normalized_x_B_T_H_W_D = _fn(
746
+ x_B_T_H_W_D,
747
+ self.layer_norm_mlp,
748
+ scale_mlp_B_T_1_1_D,
749
+ shift_mlp_B_T_1_1_D,
750
+ )
751
+ result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
752
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
753
+ return x_B_T_H_W_D
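+
+
+ # Usage sketch with tiny illustrative dims, assuming torch.nn as the operations namespace
+ # (requires a PyTorch version that provides torch.nn.RMSNorm). With no RoPE and no
+ # AdaLN-LoRA, the block maps (B, T, H, W, D) to the same shape.
+ def _demo_block():
+     blk = Block(x_dim=32, context_dim=24, num_heads=4, operations=torch.nn)
+     x = torch.randn(1, 2, 4, 4, 32)
+     emb = torch.randn(1, 2, 32)
+     ctx = torch.randn(1, 7, 24)
+     assert blk(x, emb, ctx).shape == (1, 2, 4, 4, 32)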
754
+
755
+
756
+ class MiniTrainDIT(nn.Module):
757
+ """
758
+ A clean implementation of DiT that can load and reproduce the training results of the original DiT model in Cosmos 1.
+ A general implementation of an AdaLN-modulated ViT-like (DiT) transformer for video processing.
760
+
761
+ Args:
762
+ max_img_h (int): Maximum height of the input images.
763
+ max_img_w (int): Maximum width of the input images.
764
+ max_frames (int): Maximum number of frames in the video sequence.
765
+ in_channels (int): Number of input channels (e.g., RGB channels for color images).
766
+ out_channels (int): Number of output channels.
767
+ patch_spatial (tuple): Spatial resolution of patches for input processing.
768
+ patch_temporal (int): Temporal resolution of patches for input processing.
769
+ concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
770
+ model_channels (int): Base number of channels used throughout the model.
771
+ num_blocks (int): Number of transformer blocks.
772
+ num_heads (int): Number of heads in the multi-head attention layers.
773
+ mlp_ratio (float): Expansion ratio for MLP blocks.
774
+ crossattn_emb_channels (int): Number of embedding channels for cross-attention.
775
+ pos_emb_cls (str): Type of positional embeddings.
776
+ pos_emb_learnable (bool): Whether positional embeddings are learnable.
777
+ pos_emb_interpolation (str): Method for interpolating positional embeddings.
778
+ min_fps (int): Minimum frames per second.
779
+ max_fps (int): Maximum frames per second.
780
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA.
781
+ adaln_lora_dim (int): Dimension for AdaLN-LoRA.
782
+ rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
783
+ rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
784
+ rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
785
+ extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
786
+ extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
787
+ extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
788
+ extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
789
+ """
790
+
791
+ def __init__(
792
+ self,
793
+ max_img_h: int,
794
+ max_img_w: int,
795
+ max_frames: int,
796
+ in_channels: int,
797
+ out_channels: int,
798
+ patch_spatial: int, # tuple,
799
+ patch_temporal: int,
800
+ concat_padding_mask: bool = True,
801
+ # attention settings
802
+ model_channels: int = 768,
803
+ num_blocks: int = 10,
804
+ num_heads: int = 16,
805
+ mlp_ratio: float = 4.0,
806
+ # cross attention settings
807
+ crossattn_emb_channels: int = 1024,
808
+ # positional embedding settings
809
+ pos_emb_cls: str = "sincos",
810
+ pos_emb_learnable: bool = False,
811
+ pos_emb_interpolation: str = "crop",
812
+ min_fps: int = 1,
813
+ max_fps: int = 30,
814
+ use_adaln_lora: bool = False,
815
+ adaln_lora_dim: int = 256,
816
+ rope_h_extrapolation_ratio: float = 1.0,
817
+ rope_w_extrapolation_ratio: float = 1.0,
818
+ rope_t_extrapolation_ratio: float = 1.0,
819
+ extra_per_block_abs_pos_emb: bool = False,
820
+ extra_h_extrapolation_ratio: float = 1.0,
821
+ extra_w_extrapolation_ratio: float = 1.0,
822
+ extra_t_extrapolation_ratio: float = 1.0,
823
+ rope_enable_fps_modulation: bool = True,
824
+ image_model=None,
825
+ device=None,
826
+ dtype=None,
827
+ operations=None,
828
+ ) -> None:
829
+ super().__init__()
830
+ self.dtype = dtype
831
+ self.max_img_h = max_img_h
832
+ self.max_img_w = max_img_w
833
+ self.max_frames = max_frames
834
+ self.in_channels = in_channels
835
+ self.out_channels = out_channels
836
+ self.patch_spatial = patch_spatial
837
+ self.patch_temporal = patch_temporal
838
+ self.num_heads = num_heads
839
+ self.num_blocks = num_blocks
840
+ self.model_channels = model_channels
841
+ self.concat_padding_mask = concat_padding_mask
842
+ # positional embedding settings
843
+ self.pos_emb_cls = pos_emb_cls
844
+ self.pos_emb_learnable = pos_emb_learnable
845
+ self.pos_emb_interpolation = pos_emb_interpolation
846
+ self.min_fps = min_fps
847
+ self.max_fps = max_fps
848
+ self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
849
+ self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
850
+ self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
851
+ self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
852
+ self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
853
+ self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
854
+ self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
855
+ self.rope_enable_fps_modulation = rope_enable_fps_modulation
856
+
857
+ self.build_pos_embed(device=device, dtype=dtype)
858
+ self.use_adaln_lora = use_adaln_lora
859
+ self.adaln_lora_dim = adaln_lora_dim
860
+ self.t_embedder = nn.Sequential(
861
+ Timesteps(model_channels),
862
+ TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
863
+ )
864
+
865
+ in_channels = in_channels + 1 if concat_padding_mask else in_channels
866
+ self.x_embedder = PatchEmbed(
867
+ spatial_patch_size=patch_spatial,
868
+ temporal_patch_size=patch_temporal,
869
+ in_channels=in_channels,
870
+ out_channels=model_channels,
871
+ device=device, dtype=dtype, operations=operations,
872
+ )
873
+
874
+ self.blocks = nn.ModuleList(
875
+ [
876
+ Block(
877
+ x_dim=model_channels,
878
+ context_dim=crossattn_emb_channels,
879
+ num_heads=num_heads,
880
+ mlp_ratio=mlp_ratio,
881
+ use_adaln_lora=use_adaln_lora,
882
+ adaln_lora_dim=adaln_lora_dim,
883
+ device=device, dtype=dtype, operations=operations,
884
+ )
885
+ for _ in range(num_blocks)
886
+ ]
887
+ )
888
+
889
+ self.final_layer = FinalLayer(
890
+ hidden_size=self.model_channels,
891
+ spatial_patch_size=self.patch_spatial,
892
+ temporal_patch_size=self.patch_temporal,
893
+ out_channels=self.out_channels,
894
+ use_adaln_lora=self.use_adaln_lora,
895
+ adaln_lora_dim=self.adaln_lora_dim,
896
+ device=device, dtype=dtype, operations=operations,
897
+ )
898
+
899
+ self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
900
+
901
+ def build_pos_embed(self, device=None, dtype=None) -> None:
902
+ if self.pos_emb_cls == "rope3d":
903
+ cls_type = VideoRopePosition3DEmb
904
+ else:
905
+ raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
906
+
907
+ logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
908
+ kwargs = dict(
909
+ model_channels=self.model_channels,
910
+ len_h=self.max_img_h // self.patch_spatial,
911
+ len_w=self.max_img_w // self.patch_spatial,
912
+ len_t=self.max_frames // self.patch_temporal,
913
+ max_fps=self.max_fps,
914
+ min_fps=self.min_fps,
915
+ is_learnable=self.pos_emb_learnable,
916
+ interpolation=self.pos_emb_interpolation,
917
+ head_dim=self.model_channels // self.num_heads,
918
+ h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
919
+ w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
920
+ t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
921
+ enable_fps_modulation=self.rope_enable_fps_modulation,
922
+ device=device,
923
+ )
924
+ self.pos_embedder = cls_type(
925
+ **kwargs, # type: ignore
926
+ )
927
+
928
+ if self.extra_per_block_abs_pos_emb:
929
+ kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
930
+ kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
931
+ kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
932
+ kwargs["device"] = device
933
+ kwargs["dtype"] = dtype
934
+ self.extra_pos_embedder = LearnablePosEmbAxis(
935
+ **kwargs, # type: ignore
936
+ )
937
+
938
+ def prepare_embedded_sequence(
939
+ self,
940
+ x_B_C_T_H_W: torch.Tensor,
941
+ fps: Optional[torch.Tensor] = None,
942
+ padding_mask: Optional[torch.Tensor] = None,
943
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
944
+ """
945
+ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
946
+
947
+ Args:
948
+ x_B_C_T_H_W (torch.Tensor): video
949
+ fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
950
+ If None, a default value (`self.base_fps`) will be used.
951
+ padding_mask (Optional[torch.Tensor]): optional padding mask; when `concat_padding_mask` is True it is resized and concatenated to the input as an extra channel.
952
+
953
+ Returns:
954
+ Tuple[torch.Tensor, Optional[torch.Tensor]]:
955
+ - A tensor of shape (B, T, H, W, D) with the embedded sequence.
956
+ - An optional positional embedding tensor, returned only if the positional embedding class
957
+ (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
958
+
959
+ Notes:
960
+ - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
961
+ - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
962
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
963
+ the `self.pos_embedder` with the shape [T, H, W].
964
+ - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
965
+ `self.pos_embedder` with the fps tensor.
966
+ - Otherwise, the positional embeddings are generated without considering fps.
967
+ """
968
+ if self.concat_padding_mask:
969
+ if padding_mask is None:
970
+ padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
971
+ else:
972
+ padding_mask = transforms.functional.resize(
973
+ padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
974
+ )
975
+ x_B_C_T_H_W = torch.cat(
976
+ [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
977
+ )
978
+ x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
979
+
980
+ if self.extra_per_block_abs_pos_emb:
981
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
982
+ else:
983
+ extra_pos_emb = None
984
+
985
+ if "rope" in self.pos_emb_cls.lower():
986
+ return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
987
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
988
+
989
+ return x_B_T_H_W_D, None, extra_pos_emb
990
+
991
+ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
992
+ x_B_C_Tt_Hp_Wp = rearrange(
993
+ x_B_T_H_W_M,
994
+ "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
995
+ p1=self.patch_spatial,
996
+ p2=self.patch_spatial,
997
+ t=self.patch_temporal,
998
+ )
999
+ return x_B_C_Tt_Hp_Wp
1000
+
1001
+ def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"):
1002
+ if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
1003
+ padding_mode = "reflect"
1004
+
1005
+ pad = ()
1006
+ for i in range(img.ndim - 2):
1007
+ pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
1008
+
1009
+ return torch.nn.functional.pad(img, pad, mode=padding_mode)
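+
+ # Example: with patch_size=(1, 2, 2), an input of shape (1, 16, 5, 17, 17) is padded to
+ # (1, 16, 5, 18, 18) -- each trailing dim rounds up to the next multiple of its patch size.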
1010
+
1011
+ def forward(
1012
+ self,
1013
+ x: torch.Tensor,
1014
+ timesteps: torch.Tensor,
1015
+ context: torch.Tensor,
1016
+ fps: Optional[torch.Tensor] = None,
1017
+ padding_mask: Optional[torch.Tensor] = None,
1018
+ use_gradient_checkpointing=False,
1019
+ use_gradient_checkpointing_offload=False,
1020
+ **kwargs,
1021
+ ):
1022
+ orig_shape = list(x.shape)
1023
+ x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
1024
+ x_B_C_T_H_W = x
1025
+ timesteps_B_T = timesteps
1026
+ crossattn_emb = context
1027
+ """
1028
+ Args:
1029
+ x: (B, C, T, H, W) tensor of spatio-temporal inputs
1030
+ timesteps: (B, ) tensor of timesteps
1031
+ crossattn_emb: (B, N, D) tensor of cross-attention embeddings
1032
+ """
1033
+ x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
1034
+ x_B_C_T_H_W,
1035
+ fps=fps,
1036
+ padding_mask=padding_mask,
1037
+ )
1038
+
1039
+ if timesteps_B_T.ndim == 1:
1040
+ timesteps_B_T = timesteps_B_T.unsqueeze(1)
1041
+ t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
1042
+ t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
1043
+
1044
+ # for logging purpose
1045
+ affline_scale_log_info = {}
1046
+ affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
1047
+ self.affline_scale_log_info = affline_scale_log_info
1048
+ self.affline_emb = t_embedding_B_T_D
1049
+ self.crossattn_emb = crossattn_emb
1050
+
1051
+ if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
1052
+ assert (
1053
+ x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
1054
+ ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
1055
+
1056
+ block_kwargs = {
1057
+ "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
1058
+ "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
1059
+ "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
1060
+ "transformer_options": kwargs.get("transformer_options", {}),
1061
+ }
1062
+
1063
+ # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
1064
+ # in fp32, but run attention and MLP modules in fp16.
1065
+ # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
1066
+ # quality degradation and visual artifacts.
1067
+ if x_B_T_H_W_D.dtype == torch.float16:
1068
+ x_B_T_H_W_D = x_B_T_H_W_D.float()
1069
+
1070
+ for block in self.blocks:
1071
+ x_B_T_H_W_D = gradient_checkpoint_forward(
1072
+ block,
1073
+ use_gradient_checkpointing=use_gradient_checkpointing,
1074
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1075
+ x_B_T_H_W_D=x_B_T_H_W_D,
1076
+ emb_B_T_D=t_embedding_B_T_D,
1077
+ crossattn_emb=crossattn_emb,
1078
+ **block_kwargs,
1079
+ )
1080
+
1081
+ x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
1082
+ x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
1083
+ return x_B_C_Tt_Hp_Wp
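+
+
+ # End-to-end sketch with deliberately tiny hyperparameters (the real AnimaDiT config is
+ # far larger); assumes the VideoPositionEmb base class defined earlier in this file and
+ # torch.nn as the operations namespace.
+ def _demo_mini_train_dit():
+     model = MiniTrainDIT(
+         max_img_h=32, max_img_w=32, max_frames=4, in_channels=4, out_channels=4,
+         patch_spatial=2, patch_temporal=1, model_channels=64, num_blocks=1, num_heads=4,
+         crossattn_emb_channels=32, pos_emb_cls="rope3d", operations=torch.nn,
+     )
+     x = torch.randn(1, 4, 2, 16, 16)  # (B, C, T, H, W)
+     out = model(x, timesteps=torch.tensor([500.0]), context=torch.randn(1, 7, 32))
+     assert out.shape == x.shape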
1084
+
1085
+
1086
+ def rotate_half(x):
1087
+ x1 = x[..., : x.shape[-1] // 2]
1088
+ x2 = x[..., x.shape[-1] // 2 :]
1089
+ return torch.cat((-x2, x1), dim=-1)
1090
+
1091
+
1092
+ def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1):
1093
+ cos = cos.unsqueeze(unsqueeze_dim)
1094
+ sin = sin.unsqueeze(unsqueeze_dim)
1095
+ x_embed = (x * cos) + (rotate_half(x) * sin)
1096
+ return x_embed
1097
+
1098
+
1099
+ class RotaryEmbedding(nn.Module):
1100
+ def __init__(self, head_dim):
1101
+ super().__init__()
1102
+ self.rope_theta = 10000
1103
+ inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
1104
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1105
+
1106
+ @torch.no_grad()
1107
+ def forward(self, x, position_ids):
1108
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1109
+ position_ids_expanded = position_ids[:, None, :].float()
1110
+
1111
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1112
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
1113
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1114
+ emb = torch.cat((freqs, freqs), dim=-1)
1115
+ cos = emb.cos()
1116
+ sin = emb.sin()
1117
+
1118
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
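+
+
+ # Sanity sketch (assumed shapes): this half-split RoPE variant (rotate_half) is also
+ # norm-preserving.
+ def _demo_rotary_embedding():
+     rope = RotaryEmbedding(head_dim=8)
+     x = torch.randn(1, 2, 5, 8)  # (B, n_heads, S, head_dim)
+     cos, sin = rope(x, torch.arange(5).unsqueeze(0))
+     x_rot = apply_rotary_pos_emb2(x, cos, sin)
+     assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)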
1119
+
1120
+
1121
+ class LLMAdapterAttention(nn.Module):
1122
+ def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
1123
+ super().__init__()
1124
+
1125
+ inner_dim = head_dim * n_heads
1126
+ self.n_heads = n_heads
1127
+ self.head_dim = head_dim
1128
+ self.query_dim = query_dim
1129
+ self.context_dim = context_dim
1130
+
1131
+ self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
1132
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
1133
+
1134
+ self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
1135
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
1136
+
1137
+ self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
1138
+
1139
+ self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
1140
+
1141
+ def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
1142
+ context = x if context is None else context
1143
+ input_shape = x.shape[:-1]
1144
+ q_shape = (*input_shape, self.n_heads, self.head_dim)
1145
+ context_shape = context.shape[:-1]
1146
+ kv_shape = (*context_shape, self.n_heads, self.head_dim)
1147
+
1148
+ query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
1149
+ key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
1150
+ value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
1151
+
1152
+ if position_embeddings is not None:
1153
+ assert position_embeddings_context is not None
1154
+ cos, sin = position_embeddings
1155
+ query_states = apply_rotary_pos_emb2(query_states, cos, sin)
1156
+ cos, sin = position_embeddings_context
1157
+ key_states = apply_rotary_pos_emb2(key_states, cos, sin)
1158
+
1159
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
1160
+
1161
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
1162
+ attn_output = self.o_proj(attn_output)
1163
+ return attn_output
1164
+
1165
+ def init_weights(self):
1166
+ torch.nn.init.zeros_(self.o_proj.weight)
1167
+
1168
+
1169
+ class LLMAdapterTransformerBlock(nn.Module):
1170
+ def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
1171
+ super().__init__()
1172
+ self.use_self_attn = use_self_attn
1173
+
1174
+ if self.use_self_attn:
1175
+ self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
1176
+ self.self_attn = LLMAdapterAttention(
1177
+ query_dim=model_dim,
1178
+ context_dim=model_dim,
1179
+ n_heads=num_heads,
1180
+ head_dim=model_dim//num_heads,
1181
+ device=device,
1182
+ dtype=dtype,
1183
+ operations=operations,
1184
+ )
1185
+
1186
+ self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
1187
+ self.cross_attn = LLMAdapterAttention(
1188
+ query_dim=model_dim,
1189
+ context_dim=source_dim,
1190
+ n_heads=num_heads,
1191
+ head_dim=model_dim//num_heads,
1192
+ device=device,
1193
+ dtype=dtype,
1194
+ operations=operations,
1195
+ )
1196
+
1197
+ self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
1198
+ self.mlp = nn.Sequential(
1199
+ operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
1200
+ nn.GELU(),
1201
+ operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
1202
+ )
1203
+
1204
+ def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
1205
+ if self.use_self_attn:
1206
+ normed = self.norm_self_attn(x)
1207
+ attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
1208
+ x = x + attn_out
1209
+
1210
+ normed = self.norm_cross_attn(x)
1211
+ attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
1212
+ x = x + attn_out
1213
+
1214
+ x = x + self.mlp(self.norm_mlp(x))
1215
+ return x
1216
+
1217
+ def init_weights(self):
1218
+ torch.nn.init.zeros_(self.mlp[2].weight)
1219
+ self.cross_attn.init_weights()
1220
+
1221
+
1222
+ class LLMAdapter(nn.Module):
1223
+ def __init__(
1224
+ self,
1225
+ source_dim=1024,
1226
+ target_dim=1024,
1227
+ model_dim=1024,
1228
+ num_layers=6,
1229
+ num_heads=16,
1230
+ use_self_attn=True,
1231
+ layer_norm=False,
1232
+ device=None,
1233
+ dtype=None,
1234
+ operations=None,
1235
+ ):
1236
+ super().__init__()
1237
+
1238
+ self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
1239
+ if model_dim != target_dim:
1240
+ self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
1241
+ else:
1242
+ self.in_proj = nn.Identity()
1243
+ self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
1244
+ self.blocks = nn.ModuleList([
1245
+ LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
1246
+ ])
1247
+ self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
1248
+ self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
1249
+
1250
+ def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
1251
+ if target_attention_mask is not None:
1252
+ target_attention_mask = target_attention_mask.to(torch.bool)
1253
+ if target_attention_mask.ndim == 2:
1254
+ target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
1255
+
1256
+ if source_attention_mask is not None:
1257
+ source_attention_mask = source_attention_mask.to(torch.bool)
1258
+ if source_attention_mask.ndim == 2:
1259
+ source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
1260
+
1261
+ context = source_hidden_states
1262
+ x = self.in_proj(self.embed(target_input_ids).to(context.dtype))
1263
+ position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
1264
+ position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
1265
+ position_embeddings = self.rotary_emb(x, position_ids)
1266
+ position_embeddings_context = self.rotary_emb(x, position_ids_context)
1267
+ for block in self.blocks:
1268
+ x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
1269
+ return self.norm(self.out_proj(x))
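+
+
+ # Usage sketch with small illustrative dims: the adapter cross-attends T5 token
+ # embeddings (32128-entry vocab) to the source LLM hidden states and returns
+ # per-target-token features.
+ def _demo_llm_adapter():
+     adapter = LLMAdapter(source_dim=64, target_dim=32, model_dim=32, num_layers=1,
+                          num_heads=4, operations=torch.nn)
+     source = torch.randn(2, 7, 64)          # LLM hidden states (B, S_src, source_dim)
+     ids = torch.randint(0, 32128, (2, 12))  # target token ids (B, S_tgt)
+     assert adapter(source, ids).shape == (2, 12, 32)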
1270
+
1271
+
1272
+ class AnimaDiT(MiniTrainDIT):
1273
+
1274
+ _repeated_blocks = ["Block"]
1275
+
1276
+ def __init__(self):
1277
+ kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
1278
+ super().__init__(**kwargs)
1279
+ self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
1280
+
1281
+ def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
1282
+ if text_ids is not None:
1283
+ out = self.llm_adapter(text_embeds, text_ids)
1284
+ if t5xxl_weights is not None:
1285
+ out = out * t5xxl_weights
1286
+
1287
+ if out.shape[1] < 512:
1288
+ out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
1289
+ return out
1290
+ else:
1291
+ return text_embeds
1292
+
1293
+ def forward(
1294
+ self,
1295
+ x, timesteps, context,
1296
+ use_gradient_checkpointing=False,
1297
+ use_gradient_checkpointing_offload=False,
1298
+ **kwargs
1299
+ ):
1300
+ t5xxl_ids = kwargs.pop("t5xxl_ids", None)
1301
+ if t5xxl_ids is not None:
1302
+ context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
1303
+ return super().forward(
1304
+ x, timesteps, context,
1305
+ use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1306
+ **kwargs
1307
+ )
diffsynth/models/dinov3_image_encoder.py ADDED
@@ -0,0 +1,96 @@
1
+ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
2
+ from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
3
+ import torch
4
+
5
+ from ..core.device.npu_compatible_device import get_device_type
6
+
7
+
8
+ class DINOv3ImageEncoder(DINOv3ViTModel):
9
+ def __init__(self):
10
+ config = DINOv3ViTConfig(
11
+ architectures = [
12
+ "DINOv3ViTModel"
13
+ ],
14
+ attention_dropout = 0.0,
15
+ drop_path_rate = 0.0,
16
+ dtype = "float32",
17
+ hidden_act = "silu",
18
+ hidden_size = 4096,
19
+ image_size = 224,
20
+ initializer_range = 0.02,
21
+ intermediate_size = 8192,
22
+ key_bias = False,
23
+ layer_norm_eps = 1e-05,
24
+ layerscale_value = 1.0,
25
+ mlp_bias = True,
26
+ model_type = "dinov3_vit",
27
+ num_attention_heads = 32,
28
+ num_channels = 3,
29
+ num_hidden_layers = 40,
30
+ num_register_tokens = 4,
31
+ patch_size = 16,
32
+ pos_embed_jitter = None,
33
+ pos_embed_rescale = 2.0,
34
+ pos_embed_shift = None,
35
+ proj_bias = True,
36
+ query_bias = False,
37
+ rope_theta = 100.0,
38
+ transformers_version = "4.56.1",
39
+ use_gated_mlp = True,
40
+ value_bias = False
41
+ )
42
+ super().__init__(config)
43
+ self.processor = DINOv3ViTImageProcessorFast(
44
+ crop_size = None,
45
+ data_format = "channels_first",
46
+ default_to_square = True,
47
+ device = None,
48
+ disable_grouping = None,
49
+ do_center_crop = None,
50
+ do_convert_rgb = None,
51
+ do_normalize = True,
52
+ do_rescale = True,
53
+ do_resize = True,
54
+ image_mean = [
55
+ 0.485,
56
+ 0.456,
57
+ 0.406
58
+ ],
59
+ image_processor_type = "DINOv3ViTImageProcessorFast",
60
+ image_std = [
61
+ 0.229,
62
+ 0.224,
63
+ 0.225
64
+ ],
65
+ input_data_format = None,
66
+ resample = 2,
67
+ rescale_factor = 0.00392156862745098,
68
+ return_tensors = None,
69
+ size = {
70
+ "height": 224,
71
+ "width": 224
72
+ }
73
+ )
74
+
75
+ def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
76
+ inputs = self.processor(images=image, return_tensors="pt")
77
+ pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
78
+ bool_masked_pos = None
79
+ head_mask = None
80
+
81
+ pixel_values = pixel_values.to(torch_dtype)
82
+ hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
83
+ position_embeddings = self.rope_embeddings(pixel_values)
84
+
85
+ for i, layer_module in enumerate(self.layer):
86
+ layer_head_mask = head_mask[i] if head_mask is not None else None
87
+ hidden_states = layer_module(
88
+ hidden_states,
89
+ attention_mask=layer_head_mask,
90
+ position_embeddings=position_embeddings,
91
+ )
92
+
93
+ sequence_output = self.norm(hidden_states)
94
+ pooled_output = sequence_output[:, 0, :]
95
+
96
+ return pooled_output
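+
+ # Usage sketch (hypothetical file name; pretrained weights would normally be loaded first):
+ # encoder = DINOv3ImageEncoder().to(torch.bfloat16)
+ # pooled = encoder(Image.open("example.jpg"))  # (B, 4096) CLS-token embedding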
diffsynth/models/flux2_dit.py ADDED
@@ -0,0 +1,1053 @@
1
+ import inspect
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch, math
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from ..core.attention import attention_forward
9
+ from ..core.gradient import gradient_checkpoint_forward
10
+
11
+
12
+ def get_timestep_embedding(
13
+ timesteps: torch.Tensor,
14
+ embedding_dim: int,
15
+ flip_sin_to_cos: bool = False,
16
+ downscale_freq_shift: float = 1,
17
+ scale: float = 1,
18
+ max_period: int = 10000,
19
+ ) -> torch.Tensor:
20
+ """
21
+ Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
22
+
23
+ Args:
24
+ timesteps (torch.Tensor):
25
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
26
+ embedding_dim (int):
27
+ the dimension of the output.
28
+ flip_sin_to_cos (bool):
29
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
30
+ downscale_freq_shift (float):
31
+ Controls the delta between frequencies between dimensions
32
+ scale (float):
33
+ Scaling factor applied to the embeddings.
34
+ max_period (int):
35
+ Controls the maximum frequency of the embeddings
36
+ Returns:
37
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
38
+ """
39
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
40
+
41
+ half_dim = embedding_dim // 2
42
+ exponent = -math.log(max_period) * torch.arange(
43
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
44
+ )
45
+ exponent = exponent / (half_dim - downscale_freq_shift)
46
+
47
+ emb = torch.exp(exponent)
48
+ emb = timesteps[:, None].float() * emb[None, :]
49
+
50
+ # scale embeddings
51
+ emb = scale * emb
52
+
53
+ # concat sine and cosine embeddings
54
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
55
+
56
+ # flip sine and cosine embeddings
57
+ if flip_sin_to_cos:
58
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
59
+
60
+ # zero pad
61
+ if embedding_dim % 2 == 1:
62
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
63
+ return emb
64
+
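+ # Added illustration (not upstream code): a quick shape sketch for the embedding above.
+ #     t = torch.tensor([0.0, 250.0, 999.0])
+ #     emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ #     emb.shape  # torch.Size([3, 256]); with flip_sin_to_cos=True the first 128 columns are cosines
+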
65
+
66
+ class TimestepEmbedding(nn.Module):
67
+ def __init__(
68
+ self,
69
+ in_channels: int,
70
+ time_embed_dim: int,
71
+ act_fn: str = "silu",
72
+ out_dim: int = None,
73
+ post_act_fn: Optional[str] = None,
74
+ cond_proj_dim=None,
75
+ sample_proj_bias=True,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
80
+
81
+ if cond_proj_dim is not None:
82
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
83
+ else:
84
+ self.cond_proj = None
85
+
86
+ self.act = torch.nn.SiLU()
87
+
88
+ if out_dim is not None:
89
+ time_embed_dim_out = out_dim
90
+ else:
91
+ time_embed_dim_out = time_embed_dim
92
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
93
+
94
+         # `forward` reads `self.post_act`, so it must always be defined. Only
+         # `post_act_fn=None` is supported in this trimmed-down port.
+         if post_act_fn is not None:
+             raise NotImplementedError(f"post_act_fn={post_act_fn!r} is not supported.")
+         self.post_act = None
96
+
97
+ def forward(self, sample, condition=None):
98
+ if condition is not None:
99
+ sample = sample + self.cond_proj(condition)
100
+ sample = self.linear_1(sample)
101
+
102
+ if self.act is not None:
103
+ sample = self.act(sample)
104
+
105
+ sample = self.linear_2(sample)
106
+
107
+ if self.post_act is not None:
108
+ sample = self.post_act(sample)
109
+ return sample
110
+
111
+
112
+ class Timesteps(nn.Module):
113
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
114
+ super().__init__()
115
+ self.num_channels = num_channels
116
+ self.flip_sin_to_cos = flip_sin_to_cos
117
+ self.downscale_freq_shift = downscale_freq_shift
118
+ self.scale = scale
119
+
120
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
121
+ t_emb = get_timestep_embedding(
122
+ timesteps,
123
+ self.num_channels,
124
+ flip_sin_to_cos=self.flip_sin_to_cos,
125
+ downscale_freq_shift=self.downscale_freq_shift,
126
+ scale=self.scale,
127
+ )
128
+ return t_emb
129
+
130
+
131
+ class AdaLayerNormContinuous(nn.Module):
132
+ r"""
133
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
134
+
135
+ Args:
136
+ embedding_dim (`int`): Embedding dimension to use during projection.
137
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
138
+ elementwise_affine (`bool`, defaults to `True`):
139
+ Boolean flag to denote if affine transformation should be applied.
140
+ eps (`float`, defaults to 1e-5): Epsilon factor.
141
+ bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
142
+ norm_type (`str`, defaults to `"layer_norm"`):
143
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ embedding_dim: int,
149
+ conditioning_embedding_dim: int,
150
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
151
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
152
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
153
+ # However, this is how it was implemented in the original code, and it's rather likely you should
154
+ # set `elementwise_affine` to False.
155
+ elementwise_affine=True,
156
+ eps=1e-5,
157
+ bias=True,
158
+ norm_type="layer_norm",
159
+ ):
160
+ super().__init__()
161
+ self.silu = nn.SiLU()
162
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
163
+ if norm_type == "layer_norm":
164
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
165
+
166
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
167
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
168
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
169
+ scale, shift = torch.chunk(emb, 2, dim=1)
170
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
171
+ return x
172
+
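+ # Added illustration (not upstream code): the conditioning vector is projected to per-channel
+ # (scale, shift) and broadcast over the sequence dimension.
+ #     norm = AdaLayerNormContinuous(embedding_dim=64, conditioning_embedding_dim=128, elementwise_affine=False)
+ #     x, cond = torch.randn(2, 10, 64), torch.randn(2, 128)
+ #     norm(x, cond).shape  # torch.Size([2, 10, 64])
+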
173
+
174
+ def get_1d_rotary_pos_embed(
175
+ dim: int,
176
+ pos: Union[np.ndarray, int],
177
+ theta: float = 10000.0,
178
+ use_real=False,
179
+ linear_factor=1.0,
180
+ ntk_factor=1.0,
181
+ repeat_interleave_real=True,
182
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
183
+ ):
184
+ """
185
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
186
+
187
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
188
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
189
+ data type.
190
+
191
+ Args:
192
+ dim (`int`): Dimension of the frequency tensor.
193
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
194
+ theta (`float`, *optional*, defaults to 10000.0):
195
+ Scaling factor for frequency computation. Defaults to 10000.0.
196
+ use_real (`bool`, *optional*):
197
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
198
+ linear_factor (`float`, *optional*, defaults to 1.0):
199
+ Scaling factor for the context extrapolation. Defaults to 1.0.
200
+ ntk_factor (`float`, *optional*, defaults to 1.0):
201
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
202
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
203
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
204
+ Otherwise, they are concatenated with themselves.
205
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
206
+ the dtype of the frequency tensor.
207
+ Returns:
208
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
209
+ """
210
+ assert dim % 2 == 0
211
+
212
+ if isinstance(pos, int):
213
+ pos = torch.arange(pos)
214
+ if isinstance(pos, np.ndarray):
215
+ pos = torch.from_numpy(pos) # type: ignore # [S]
216
+
217
+ theta = theta * ntk_factor
218
+ freqs = (
219
+ 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
220
+ ) # [D/2]
221
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
222
+ is_npu = freqs.device.type == "npu"
223
+ if is_npu:
224
+ freqs = freqs.float()
225
+ if use_real and repeat_interleave_real:
226
+ # flux, hunyuan-dit, cogvideox
227
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
228
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
229
+ return freqs_cos, freqs_sin
230
+ elif use_real:
231
+ # stable audio, allegro
232
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
233
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
234
+ return freqs_cos, freqs_sin
235
+ else:
236
+ # lumina
237
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
238
+ return freqs_cis
239
+
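+ # Added illustration (not upstream code): with use_real=True the function returns a (cos, sin)
+ # pair, each [S, D], where each of the D/2 frequencies appears twice.
+ #     cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True, repeat_interleave_real=True)
+ #     cos.shape, sin.shape  # (torch.Size([16, 64]), torch.Size([16, 64]))
+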
240
+
241
+ def apply_rotary_emb(
242
+ x: torch.Tensor,
243
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
244
+ use_real: bool = True,
245
+ use_real_unbind_dim: int = -1,
246
+ sequence_dim: int = 2,
247
+ ) -> torch.Tensor:
248
+ """
249
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
250
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
251
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
252
+ tensors contain rotary embeddings and are returned as real tensors.
253
+
254
+ Args:
255
+ x (`torch.Tensor`):
256
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
257
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
258
+
259
+ Returns:
260
+ torch.Tensor: The input tensor with rotary embeddings applied.
261
+ """
262
+ if use_real:
263
+ cos, sin = freqs_cis # [S, D]
264
+ if sequence_dim == 2:
265
+ cos = cos[None, None, :, :]
266
+ sin = sin[None, None, :, :]
267
+ elif sequence_dim == 1:
268
+ cos = cos[None, :, None, :]
269
+ sin = sin[None, :, None, :]
270
+ else:
271
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
272
+
273
+ cos, sin = cos.to(x.device), sin.to(x.device)
274
+
275
+ if use_real_unbind_dim == -1:
276
+ # Used for flux, cogvideox, hunyuan-dit
277
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
278
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
279
+ elif use_real_unbind_dim == -2:
280
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
281
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
282
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
283
+ else:
284
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
285
+
286
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
287
+
288
+ return out
289
+ else:
290
+ # used for lumina
291
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
292
+ freqs_cis = freqs_cis.unsqueeze(2)
293
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
294
+
295
+ return x_out.type_as(x)
296
+
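+ # Added illustration (not upstream code): rotating a [B, S, H, D] query with the (cos, sin)
+ # pair from get_1d_rotary_pos_embed; sequence_dim=1 matches the layout used by the processors below.
+ #     q = torch.randn(1, 16, 8, 64)
+ #     q_rot = apply_rotary_emb(q, (cos, sin), sequence_dim=1)  # same shape, pairwise rotation
+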
297
+ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
298
+ query = attn.to_q(hidden_states)
299
+ key = attn.to_k(hidden_states)
300
+ value = attn.to_v(hidden_states)
301
+
302
+ encoder_query = encoder_key = encoder_value = None
303
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
304
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
305
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
306
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
307
+
308
+ return query, key, value, encoder_query, encoder_key, encoder_value
309
+
310
+
311
+ def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
312
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
313
+
314
+ encoder_query = encoder_key = encoder_value = None
315
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
316
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
317
+
318
+ return query, key, value, encoder_query, encoder_key, encoder_value
319
+
320
+
321
+ def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
322
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
323
+
324
+
325
+ class Flux2SwiGLU(nn.Module):
326
+ """
327
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
328
+ layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
329
+ """
330
+
331
+ def __init__(self):
332
+ super().__init__()
333
+ self.gate_fn = nn.SiLU()
334
+
335
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
336
+ x1, x2 = x.chunk(2, dim=-1)
337
+ x = self.gate_fn(x1) * x2
338
+ return x
339
+
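+ # Added illustration (not upstream code): the gate/value split halves the last dimension.
+ #     x = torch.randn(2, 16, 512)
+ #     Flux2SwiGLU()(x).shape  # torch.Size([2, 16, 256]) == silu(x[..., :256]) * x[..., 256:]
+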
340
+
341
+ class Flux2FeedForward(nn.Module):
342
+ def __init__(
343
+ self,
344
+ dim: int,
345
+ dim_out: Optional[int] = None,
346
+ mult: float = 3.0,
347
+ inner_dim: Optional[int] = None,
348
+ bias: bool = False,
349
+ ):
350
+ super().__init__()
351
+ if inner_dim is None:
352
+ inner_dim = int(dim * mult)
353
+ dim_out = dim_out or dim
354
+
355
+ # Flux2SwiGLU will reduce the dimension by half
356
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
357
+ self.act_fn = Flux2SwiGLU()
358
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
359
+
360
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
361
+ x = self.linear_in(x)
362
+ x = self.act_fn(x)
363
+ x = self.linear_out(x)
364
+ return x
365
+
366
+
367
+ class Flux2AttnProcessor:
368
+ _attention_backend = None
369
+ _parallel_config = None
370
+
371
+ def __init__(self):
372
+ if not hasattr(F, "scaled_dot_product_attention"):
373
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")
374
+
375
+ def __call__(
376
+ self,
377
+ attn: "Flux2Attention",
378
+ hidden_states: torch.Tensor,
379
+ encoder_hidden_states: torch.Tensor = None,
380
+ attention_mask: Optional[torch.Tensor] = None,
381
+ image_rotary_emb: Optional[torch.Tensor] = None,
382
+ ) -> torch.Tensor:
383
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
384
+ attn, hidden_states, encoder_hidden_states
385
+ )
386
+
387
+ query = query.unflatten(-1, (attn.heads, -1))
388
+ key = key.unflatten(-1, (attn.heads, -1))
389
+ value = value.unflatten(-1, (attn.heads, -1))
390
+
391
+ query = attn.norm_q(query)
392
+ key = attn.norm_k(key)
393
+
394
+ if attn.added_kv_proj_dim is not None:
395
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
396
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
397
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
398
+
399
+ encoder_query = attn.norm_added_q(encoder_query)
400
+ encoder_key = attn.norm_added_k(encoder_key)
401
+
402
+ query = torch.cat([encoder_query, query], dim=1)
403
+ key = torch.cat([encoder_key, key], dim=1)
404
+ value = torch.cat([encoder_value, value], dim=1)
405
+
406
+ if image_rotary_emb is not None:
407
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
408
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
409
+
410
+ query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
411
+ hidden_states = attention_forward(
412
+ query,
413
+ key,
414
+ value,
415
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
416
+ )
417
+ hidden_states = hidden_states.flatten(2, 3)
418
+ hidden_states = hidden_states.to(query.dtype)
419
+
420
+ if encoder_hidden_states is not None:
421
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
422
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
423
+ )
424
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
425
+
426
+ hidden_states = attn.to_out[0](hidden_states)
427
+ hidden_states = attn.to_out[1](hidden_states)
428
+
429
+ if encoder_hidden_states is not None:
430
+ return hidden_states, encoder_hidden_states
431
+ else:
432
+ return hidden_states
433
+
434
+
435
+ class Flux2Attention(torch.nn.Module):
436
+ _default_processor_cls = Flux2AttnProcessor
437
+ _available_processors = [Flux2AttnProcessor]
438
+
439
+ def __init__(
440
+ self,
441
+ query_dim: int,
442
+ heads: int = 8,
443
+ dim_head: int = 64,
444
+ dropout: float = 0.0,
445
+ bias: bool = False,
446
+ added_kv_proj_dim: Optional[int] = None,
447
+ added_proj_bias: Optional[bool] = True,
448
+ out_bias: bool = True,
449
+ eps: float = 1e-5,
450
+ out_dim: int = None,
451
+ elementwise_affine: bool = True,
452
+ processor=None,
453
+ ):
454
+ super().__init__()
455
+
456
+ self.head_dim = dim_head
457
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
458
+ self.query_dim = query_dim
459
+ self.out_dim = out_dim if out_dim is not None else query_dim
460
+ self.heads = out_dim // dim_head if out_dim is not None else heads
461
+
462
+ self.use_bias = bias
463
+ self.dropout = dropout
464
+
465
+ self.added_kv_proj_dim = added_kv_proj_dim
466
+ self.added_proj_bias = added_proj_bias
467
+
468
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
469
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
470
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
471
+
472
+ # QK Norm
473
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
474
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
475
+
476
+ self.to_out = torch.nn.ModuleList([])
477
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
478
+ self.to_out.append(torch.nn.Dropout(dropout))
479
+
480
+ if added_kv_proj_dim is not None:
481
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
482
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
483
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
484
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
485
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
486
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
487
+
488
+ if processor is None:
489
+ processor = self._default_processor_cls()
490
+ self.processor = processor
491
+
492
+ def forward(
493
+ self,
494
+ hidden_states: torch.Tensor,
495
+ encoder_hidden_states: Optional[torch.Tensor] = None,
496
+ attention_mask: Optional[torch.Tensor] = None,
497
+ image_rotary_emb: Optional[torch.Tensor] = None,
498
+ **kwargs,
499
+ ) -> torch.Tensor:
500
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
501
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
502
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
503
+
504
+
505
+ class Flux2ParallelSelfAttnProcessor:
506
+ _attention_backend = None
507
+ _parallel_config = None
508
+
509
+ def __init__(self):
510
+ if not hasattr(F, "scaled_dot_product_attention"):
511
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")
512
+
513
+ def __call__(
514
+ self,
515
+ attn: "Flux2ParallelSelfAttention",
516
+ hidden_states: torch.Tensor,
517
+ attention_mask: Optional[torch.Tensor] = None,
518
+ image_rotary_emb: Optional[torch.Tensor] = None,
519
+ ) -> torch.Tensor:
520
+ # Parallel in (QKV + MLP in) projection
521
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
522
+ qkv, mlp_hidden_states = torch.split(
523
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
524
+ )
525
+
526
+ # Handle the attention logic
527
+ query, key, value = qkv.chunk(3, dim=-1)
528
+
529
+ query = query.unflatten(-1, (attn.heads, -1))
530
+ key = key.unflatten(-1, (attn.heads, -1))
531
+ value = value.unflatten(-1, (attn.heads, -1))
532
+
533
+ query = attn.norm_q(query)
534
+ key = attn.norm_k(key)
535
+
536
+ if image_rotary_emb is not None:
537
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
538
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
539
+
540
+ query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
541
+ hidden_states = attention_forward(
542
+ query,
543
+ key,
544
+ value,
545
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
546
+ )
547
+ hidden_states = hidden_states.flatten(2, 3)
548
+ hidden_states = hidden_states.to(query.dtype)
549
+
550
+ # Handle the feedforward (FF) logic
551
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
552
+
553
+ # Concatenate and parallel output projection
554
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
555
+ hidden_states = attn.to_out(hidden_states)
556
+
557
+ return hidden_states
558
+
559
+
560
+ class Flux2ParallelSelfAttention(torch.nn.Module):
561
+ """
562
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
563
+
564
+ This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
565
+ input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
566
+ paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
567
+ """
568
+
569
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
570
+ _available_processors = [Flux2ParallelSelfAttnProcessor]
571
+ # Does not support QKV fusion as the QKV projections are always fused
572
+ _supports_qkv_fusion = False
573
+
574
+ def __init__(
575
+ self,
576
+ query_dim: int,
577
+ heads: int = 8,
578
+ dim_head: int = 64,
579
+ dropout: float = 0.0,
580
+ bias: bool = False,
581
+ out_bias: bool = True,
582
+ eps: float = 1e-5,
583
+ out_dim: int = None,
584
+ elementwise_affine: bool = True,
585
+ mlp_ratio: float = 4.0,
586
+ mlp_mult_factor: int = 2,
587
+ processor=None,
588
+ ):
589
+ super().__init__()
590
+
591
+ self.head_dim = dim_head
592
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
593
+ self.query_dim = query_dim
594
+ self.out_dim = out_dim if out_dim is not None else query_dim
595
+ self.heads = out_dim // dim_head if out_dim is not None else heads
596
+
597
+ self.use_bias = bias
598
+ self.dropout = dropout
599
+
600
+ self.mlp_ratio = mlp_ratio
601
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
602
+ self.mlp_mult_factor = mlp_mult_factor
603
+
604
+ # Fused QKV projections + MLP input projection
605
+ self.to_qkv_mlp_proj = torch.nn.Linear(
606
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
607
+ )
608
+ self.mlp_act_fn = Flux2SwiGLU()
609
+
610
+ # QK Norm
611
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
612
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
613
+
614
+ # Fused attention output projection + MLP output projection
615
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
616
+
617
+ if processor is None:
618
+ processor = self._default_processor_cls()
619
+ self.processor = processor
620
+
621
+ def forward(
622
+ self,
623
+ hidden_states: torch.Tensor,
624
+ attention_mask: Optional[torch.Tensor] = None,
625
+ image_rotary_emb: Optional[torch.Tensor] = None,
626
+ **kwargs,
627
+ ) -> torch.Tensor:
628
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
629
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
630
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
631
+
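+ # Sizing sketch (added for clarity, not upstream code): with the Flux 2 defaults used by
+ # Flux2DiT below (query_dim=6144, heads=48, dim_head=128, mlp_ratio=3.0, mlp_mult_factor=2),
+ # the fused input projection is Linear(6144, 3*6144 + 18432*2) = Linear(6144, 55296) and the
+ # fused output projection is Linear(6144 + 18432, 6144).
+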
632
+
633
+ class Flux2SingleTransformerBlock(nn.Module):
634
+ def __init__(
635
+ self,
636
+ dim: int,
637
+ num_attention_heads: int,
638
+ attention_head_dim: int,
639
+ mlp_ratio: float = 3.0,
640
+ eps: float = 1e-6,
641
+ bias: bool = False,
642
+ ):
643
+ super().__init__()
644
+
645
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
646
+
647
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
648
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
649
+ # for a visual depiction of this type of transformer block.
650
+ self.attn = Flux2ParallelSelfAttention(
651
+ query_dim=dim,
652
+ dim_head=attention_head_dim,
653
+ heads=num_attention_heads,
654
+ out_dim=dim,
655
+ bias=bias,
656
+ out_bias=bias,
657
+ eps=eps,
658
+ mlp_ratio=mlp_ratio,
659
+ mlp_mult_factor=2,
660
+ processor=Flux2ParallelSelfAttnProcessor(),
661
+ )
662
+
663
+ def forward(
664
+ self,
665
+ hidden_states: torch.Tensor,
666
+ encoder_hidden_states: Optional[torch.Tensor],
667
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
668
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
669
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
670
+ split_hidden_states: bool = False,
671
+ text_seq_len: Optional[int] = None,
672
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
673
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
674
+ # concatenated
675
+ if encoder_hidden_states is not None:
676
+ text_seq_len = encoder_hidden_states.shape[1]
677
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
678
+
679
+ mod_shift, mod_scale, mod_gate = temb_mod_params
680
+
681
+ norm_hidden_states = self.norm(hidden_states)
682
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
683
+
684
+ joint_attention_kwargs = joint_attention_kwargs or {}
685
+ attn_output = self.attn(
686
+ hidden_states=norm_hidden_states,
687
+ image_rotary_emb=image_rotary_emb,
688
+ **joint_attention_kwargs,
689
+ )
690
+
691
+ hidden_states = hidden_states + mod_gate * attn_output
692
+ if hidden_states.dtype == torch.float16:
693
+ hidden_states = hidden_states.clip(-65504, 65504)
694
+
695
+ if split_hidden_states:
696
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
697
+ return encoder_hidden_states, hidden_states
698
+ else:
699
+ return hidden_states
700
+
701
+
702
+ class Flux2TransformerBlock(nn.Module):
703
+ def __init__(
704
+ self,
705
+ dim: int,
706
+ num_attention_heads: int,
707
+ attention_head_dim: int,
708
+ mlp_ratio: float = 3.0,
709
+ eps: float = 1e-6,
710
+ bias: bool = False,
711
+ ):
712
+ super().__init__()
713
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
714
+
715
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
716
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
717
+
718
+ self.attn = Flux2Attention(
719
+ query_dim=dim,
720
+ added_kv_proj_dim=dim,
721
+ dim_head=attention_head_dim,
722
+ heads=num_attention_heads,
723
+ out_dim=dim,
724
+ bias=bias,
725
+ added_proj_bias=bias,
726
+ out_bias=bias,
727
+ eps=eps,
728
+ processor=Flux2AttnProcessor(),
729
+ )
730
+
731
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
732
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
733
+
734
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
735
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.Tensor,
740
+ encoder_hidden_states: torch.Tensor,
741
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
742
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
743
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
744
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
745
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
746
+ joint_attention_kwargs = joint_attention_kwargs or {}
747
+
748
+ # Modulation parameters shape: [1, 1, self.dim]
749
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
750
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
751
+
752
+ # Img stream
753
+ norm_hidden_states = self.norm1(hidden_states)
754
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
755
+
756
+ # Conditioning txt stream
757
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
758
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
759
+
760
+ # Attention on concatenated img + txt stream
761
+ attention_outputs = self.attn(
762
+ hidden_states=norm_hidden_states,
763
+ encoder_hidden_states=norm_encoder_hidden_states,
764
+ image_rotary_emb=image_rotary_emb,
765
+ **joint_attention_kwargs,
766
+ )
767
+
768
+ attn_output, context_attn_output = attention_outputs
769
+
770
+ # Process attention outputs for the image stream (`hidden_states`).
771
+ attn_output = gate_msa * attn_output
772
+ hidden_states = hidden_states + attn_output
773
+
774
+ norm_hidden_states = self.norm2(hidden_states)
775
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
776
+
777
+ ff_output = self.ff(norm_hidden_states)
778
+ hidden_states = hidden_states + gate_mlp * ff_output
779
+
780
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
781
+ context_attn_output = c_gate_msa * context_attn_output
782
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
783
+
784
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
785
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
786
+
787
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
788
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
789
+ if encoder_hidden_states.dtype == torch.float16:
790
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
791
+
792
+ return encoder_hidden_states, hidden_states
793
+
794
+
795
+ class Flux2PosEmbed(nn.Module):
796
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
797
+ def __init__(self, theta: int, axes_dim: List[int]):
798
+ super().__init__()
799
+ self.theta = theta
800
+ self.axes_dim = axes_dim
801
+
802
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
803
+ # Expected ids shape: [S, len(self.axes_dim)]
804
+ cos_out = []
805
+ sin_out = []
806
+ pos = ids.float()
807
+ is_mps = ids.device.type == "mps"
808
+ is_npu = ids.device.type == "npu"
809
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
810
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
811
+ for i in range(len(self.axes_dim)):
812
+ cos, sin = get_1d_rotary_pos_embed(
813
+ self.axes_dim[i],
814
+ pos[..., i],
815
+ theta=self.theta,
816
+ repeat_interleave_real=True,
817
+ use_real=True,
818
+ freqs_dtype=freqs_dtype,
819
+ )
820
+ cos_out.append(cos)
821
+ sin_out.append(sin)
822
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
823
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
824
+ return freqs_cos, freqs_sin
825
+
826
+
827
+ class Flux2TimestepGuidanceEmbeddings(nn.Module):
828
+ def __init__(
829
+ self,
830
+ in_channels: int = 256,
831
+ embedding_dim: int = 6144,
832
+ bias: bool = False,
833
+ guidance_embeds: bool = True,
834
+ ):
835
+ super().__init__()
836
+
837
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
838
+ self.timestep_embedder = TimestepEmbedding(
839
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
840
+ )
841
+
842
+ if guidance_embeds:
843
+ self.guidance_embedder = TimestepEmbedding(
844
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
845
+ )
846
+ else:
847
+ self.guidance_embedder = None
848
+
849
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
850
+ timesteps_proj = self.time_proj(timestep)
851
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
852
+
853
+ if guidance is not None and self.guidance_embedder is not None:
854
+ guidance_proj = self.time_proj(guidance)
855
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
856
+ time_guidance_emb = timesteps_emb + guidance_emb
857
+ return time_guidance_emb
858
+ else:
859
+ return timesteps_emb
860
+
861
+
862
+ class Flux2Modulation(nn.Module):
863
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
864
+ super().__init__()
865
+ self.mod_param_sets = mod_param_sets
866
+
867
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
868
+ self.act_fn = nn.SiLU()
869
+
870
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
871
+ mod = self.act_fn(temb)
872
+ mod = self.linear(mod)
873
+
874
+ if mod.ndim == 2:
875
+ mod = mod.unsqueeze(1)
876
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
877
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
878
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
879
+
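+ # Added illustration (not upstream code): with mod_param_sets=2 the forward returns
+ # ((shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp)), each [B, 1, dim].
+ #     mod = Flux2Modulation(dim=64, mod_param_sets=2)
+ #     attn_params, mlp_params = mod(torch.randn(2, 64))
+ #     len(attn_params), attn_params[0].shape  # (3, torch.Size([2, 1, 64]))
+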
880
+
881
+ class Flux2DiT(torch.nn.Module):
882
+
883
+ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
884
+
885
+ def __init__(
886
+ self,
887
+ patch_size: int = 1,
888
+ in_channels: int = 128,
889
+ out_channels: Optional[int] = None,
890
+ num_layers: int = 8,
891
+ num_single_layers: int = 48,
892
+ attention_head_dim: int = 128,
893
+ num_attention_heads: int = 48,
894
+ joint_attention_dim: int = 15360,
895
+ timestep_guidance_channels: int = 256,
896
+ mlp_ratio: float = 3.0,
897
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
898
+ rope_theta: int = 2000,
899
+ eps: float = 1e-6,
900
+ guidance_embeds: bool = True,
901
+ ):
902
+ super().__init__()
903
+ self.out_channels = out_channels or in_channels
904
+ self.inner_dim = num_attention_heads * attention_head_dim
905
+
906
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
907
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
908
+
909
+ # 2. Combined timestep + guidance embedding
910
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
911
+ in_channels=timestep_guidance_channels,
912
+ embedding_dim=self.inner_dim,
913
+ bias=False,
914
+ guidance_embeds=guidance_embeds,
915
+ )
916
+
917
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
918
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
919
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
920
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
921
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
922
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
923
+
924
+ # 4. Input projections
925
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
926
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
927
+
928
+ # 5. Double Stream Transformer Blocks
929
+ self.transformer_blocks = nn.ModuleList(
930
+ [
931
+ Flux2TransformerBlock(
932
+ dim=self.inner_dim,
933
+ num_attention_heads=num_attention_heads,
934
+ attention_head_dim=attention_head_dim,
935
+ mlp_ratio=mlp_ratio,
936
+ eps=eps,
937
+ bias=False,
938
+ )
939
+ for _ in range(num_layers)
940
+ ]
941
+ )
942
+
943
+ # 6. Single Stream Transformer Blocks
944
+ self.single_transformer_blocks = nn.ModuleList(
945
+ [
946
+ Flux2SingleTransformerBlock(
947
+ dim=self.inner_dim,
948
+ num_attention_heads=num_attention_heads,
949
+ attention_head_dim=attention_head_dim,
950
+ mlp_ratio=mlp_ratio,
951
+ eps=eps,
952
+ bias=False,
953
+ )
954
+ for _ in range(num_single_layers)
955
+ ]
956
+ )
957
+
958
+ # 7. Output layers
959
+ self.norm_out = AdaLayerNormContinuous(
960
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
961
+ )
962
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
963
+
964
+ self.gradient_checkpointing = False
965
+
966
+ def forward(
967
+ self,
968
+ hidden_states: torch.Tensor,
969
+ encoder_hidden_states: torch.Tensor = None,
970
+ timestep: torch.LongTensor = None,
971
+ img_ids: torch.Tensor = None,
972
+ txt_ids: torch.Tensor = None,
973
+ guidance: torch.Tensor = None,
974
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
975
+ use_gradient_checkpointing=False,
976
+ use_gradient_checkpointing_offload=False,
977
+ ):
978
+ # 0. Handle input arguments
979
+ if joint_attention_kwargs is not None:
980
+ joint_attention_kwargs = joint_attention_kwargs.copy()
981
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
982
+ else:
983
+ lora_scale = 1.0
984
+
985
+ num_txt_tokens = encoder_hidden_states.shape[1]
986
+
987
+ # 1. Calculate timestep embedding and modulation parameters
988
+ timestep = timestep.to(hidden_states.dtype) * 1000
989
+
990
+ if guidance is not None:
991
+ guidance = guidance.to(hidden_states.dtype) * 1000
992
+
993
+ temb = self.time_guidance_embed(timestep, guidance)
994
+
995
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
996
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
997
+ single_stream_mod = self.single_stream_modulation(temb)[0]
998
+
999
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
1000
+ hidden_states = self.x_embedder(hidden_states)
1001
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1002
+
1003
+ # 3. Calculate RoPE embeddings from image and text tokens
1004
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
1005
+ # text prompts of different lengths.
1006
+ if img_ids.ndim == 3:
1007
+ img_ids = img_ids[0]
1008
+ if txt_ids.ndim == 3:
1009
+ txt_ids = txt_ids[0]
1010
+
1011
+ image_rotary_emb = self.pos_embed(img_ids)
1012
+ text_rotary_emb = self.pos_embed(txt_ids)
1013
+ concat_rotary_emb = (
1014
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
1015
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
1016
+ )
1017
+
1018
+ # 4. Double Stream Transformer Blocks
1019
+ for index_block, block in enumerate(self.transformer_blocks):
1020
+ encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
1021
+ block,
1022
+ use_gradient_checkpointing=use_gradient_checkpointing,
1023
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1024
+ hidden_states=hidden_states,
1025
+ encoder_hidden_states=encoder_hidden_states,
1026
+ temb_mod_params_img=double_stream_mod_img,
1027
+ temb_mod_params_txt=double_stream_mod_txt,
1028
+ image_rotary_emb=concat_rotary_emb,
1029
+ joint_attention_kwargs=joint_attention_kwargs,
1030
+ )
1031
+ # Concatenate text and image streams for single-block inference
1032
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1033
+
1034
+ # 5. Single Stream Transformer Blocks
1035
+ for index_block, block in enumerate(self.single_transformer_blocks):
1036
+ hidden_states = gradient_checkpoint_forward(
1037
+ block,
1038
+ use_gradient_checkpointing=use_gradient_checkpointing,
1039
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1040
+ hidden_states=hidden_states,
1041
+ encoder_hidden_states=None,
1042
+ temb_mod_params=single_stream_mod,
1043
+ image_rotary_emb=concat_rotary_emb,
1044
+ joint_attention_kwargs=joint_attention_kwargs,
1045
+ )
1046
+ # Remove text tokens from concatenated stream
1047
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
1048
+
1049
+ # 6. Output layers
1050
+ hidden_states = self.norm_out(hidden_states, temb)
1051
+ output = self.proj_out(hidden_states)
1052
+
1053
+ return output
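+
+
+ # Added illustration (not upstream code): a minimal smoke-test sketch for Flux2DiT. All
+ # shapes are assumptions for demonstration; img_ids/txt_ids carry one position id per RoPE
+ # axis (len(axes_dims_rope) == 4 by default). This helper is never called by the library.
+ def _flux2_dit_smoke_test():
+     model = Flux2DiT(num_layers=1, num_single_layers=1, num_attention_heads=2,
+                      attention_head_dim=128, in_channels=128, joint_attention_dim=64)
+     hidden_states = torch.randn(1, 16, 128)         # 16 packed image tokens
+     encoder_hidden_states = torch.randn(1, 4, 64)   # 4 text tokens
+     img_ids = torch.zeros(16, 4)
+     txt_ids = torch.zeros(4, 4)
+     timestep = torch.tensor([0.5])
+     guidance = torch.tensor([4.0])
+     out = model(hidden_states, encoder_hidden_states, timestep, img_ids, txt_ids, guidance=guidance)
+     assert out.shape == (1, 16, 128)
+     return out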
diffsynth/models/flux2_text_encoder.py ADDED
@@ -0,0 +1,58 @@
 
 
1
+ from transformers import Mistral3ForConditionalGeneration, Mistral3Config
2
+
3
+
4
+ class Flux2TextEncoder(Mistral3ForConditionalGeneration):
5
+ def __init__(self):
6
+ config = Mistral3Config(**{
7
+ "architectures": [
8
+ "Mistral3ForConditionalGeneration"
9
+ ],
10
+ "dtype": "bfloat16",
11
+ "image_token_index": 10,
12
+ "model_type": "mistral3",
13
+ "multimodal_projector_bias": False,
14
+ "projector_hidden_act": "gelu",
15
+ "spatial_merge_size": 2,
16
+ "text_config": {
17
+ "attention_dropout": 0.0,
18
+ "dtype": "bfloat16",
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 5120,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 32768,
24
+ "max_position_embeddings": 131072,
25
+ "model_type": "mistral",
26
+ "num_attention_heads": 32,
27
+ "num_hidden_layers": 40,
28
+ "num_key_value_heads": 8,
29
+ "rms_norm_eps": 1e-05,
30
+ "rope_theta": 1000000000.0,
31
+ "sliding_window": None,
32
+ "use_cache": True,
33
+ "vocab_size": 131072
34
+ },
35
+ "transformers_version": "4.57.1",
36
+ "vision_config": {
37
+ "attention_dropout": 0.0,
38
+ "dtype": "bfloat16",
39
+ "head_dim": 64,
40
+ "hidden_act": "silu",
41
+ "hidden_size": 1024,
42
+ "image_size": 1540,
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 4096,
45
+ "model_type": "pixtral",
46
+ "num_attention_heads": 16,
47
+ "num_channels": 3,
48
+ "num_hidden_layers": 24,
49
+ "patch_size": 14,
50
+ "rope_theta": 10000.0
51
+ },
52
+ "vision_feature_layer": -1
53
+ })
54
+ super().__init__(config)
55
+
56
+ def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
57
+ return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
58
+
diffsynth/models/flux2_vae.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/flux_controlnet.py ADDED
@@ -0,0 +1,384 @@
 
 
1
+ import hashlib
+ import torch
2
+ from einops import rearrange, repeat
3
+ from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
4
+ # from .utils import hash_state_dict_keys, init_weights_on_device
5
+ from contextlib import contextmanager
6
+
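+ # `hash_state_dict_keys` below needs this helper, previously imported from `.utils` (see the
+ # commented import above). Re-inlined here as a minimal sketch assumed to match the original
+ # DiffSynth helper, so the bundled module stays dependency-free.
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
+     keys = []
+     for key, value in state_dict.items():
+         if isinstance(value, torch.Tensor):
+             if with_shape:
+                 shape = "_".join(map(str, list(value.shape)))
+                 keys.append(key + ":" + shape)
+             keys.append(key)
+         elif isinstance(value, dict):
+             keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
+     keys.sort()
+     return ",".join(keys)
+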
7
+ def hash_state_dict_keys(state_dict, with_shape=True):
8
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
9
+ keys_str = keys_str.encode(encoding="UTF-8")
10
+ return hashlib.md5(keys_str).hexdigest()
11
+
12
+ @contextmanager
13
+ def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
14
+
15
+ old_register_parameter = torch.nn.Module.register_parameter
16
+ if include_buffers:
17
+ old_register_buffer = torch.nn.Module.register_buffer
18
+
19
+ def register_empty_parameter(module, name, param):
20
+ old_register_parameter(module, name, param)
21
+ if param is not None:
22
+ param_cls = type(module._parameters[name])
23
+ kwargs = module._parameters[name].__dict__
24
+ kwargs["requires_grad"] = param.requires_grad
25
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
26
+
27
+ def register_empty_buffer(module, name, buffer, persistent=True):
28
+ old_register_buffer(module, name, buffer, persistent=persistent)
29
+ if buffer is not None:
30
+ module._buffers[name] = module._buffers[name].to(device)
31
+
32
+ def patch_tensor_constructor(fn):
33
+ def wrapper(*args, **kwargs):
34
+ kwargs["device"] = device
35
+ return fn(*args, **kwargs)
36
+
37
+ return wrapper
38
+
39
+ if include_buffers:
40
+ tensor_constructors_to_patch = {
41
+ torch_function_name: getattr(torch, torch_function_name)
42
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
43
+ }
44
+ else:
45
+ tensor_constructors_to_patch = {}
46
+
47
+ try:
48
+ torch.nn.Module.register_parameter = register_empty_parameter
49
+ if include_buffers:
50
+ torch.nn.Module.register_buffer = register_empty_buffer
51
+ for torch_function_name in tensor_constructors_to_patch.keys():
52
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
53
+ yield
54
+ finally:
55
+ torch.nn.Module.register_parameter = old_register_parameter
56
+ if include_buffers:
57
+ torch.nn.Module.register_buffer = old_register_buffer
58
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
59
+ setattr(torch, torch_function_name, old_torch_function)
60
+
61
+ class FluxControlNet(torch.nn.Module):
62
+ def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
63
+ super().__init__()
64
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
65
+ self.time_embedder = TimestepEmbeddings(256, 3072)
66
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
67
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
68
+ self.context_embedder = torch.nn.Linear(4096, 3072)
69
+ self.x_embedder = torch.nn.Linear(64, 3072)
70
+
71
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
72
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
73
+
74
+ self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
75
+ self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
76
+
77
+ self.mode_dict = mode_dict
78
+ self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
79
+ self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
80
+
81
+
82
+ def prepare_image_ids(self, latents):
83
+ batch_size, _, height, width = latents.shape
84
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
85
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
86
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
87
+
88
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
89
+
90
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
91
+ latent_image_ids = latent_image_ids.reshape(
92
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
93
+ )
94
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
95
+
96
+ return latent_image_ids
97
+
98
+
99
+ def patchify(self, hidden_states):
100
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
101
+ return hidden_states
102
+
103
+
104
+ def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
105
+ if len(res_stack) == 0:
106
+ return [torch.zeros_like(hidden_states)] * num_blocks
107
+ interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
108
+ aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
109
+ return aligned_res_stack
110
+
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states,
115
+ controlnet_conditioning,
116
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
117
+ processor_id=None,
118
+ tiled=False, tile_size=128, tile_stride=64,
119
+ **kwargs
120
+ ):
121
+ if image_ids is None:
122
+ image_ids = self.prepare_image_ids(hidden_states)
123
+
124
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
125
+ if self.guidance_embedder is not None:
126
+ guidance = guidance * 1000
127
+ conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
128
+ prompt_emb = self.context_embedder(prompt_emb)
129
+ if self.controlnet_mode_embedder is not None: # Different from FluxDiT
130
+ processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
131
+ processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
132
+ prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
133
+ text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
134
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
135
+
136
+ hidden_states = self.patchify(hidden_states)
137
+ hidden_states = self.x_embedder(hidden_states)
138
+ controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
139
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
140
+
141
+ controlnet_res_stack = []
142
+ for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
143
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
144
+ controlnet_res_stack.append(controlnet_block(hidden_states))
145
+
146
+ controlnet_single_res_stack = []
147
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
148
+ for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
149
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
150
+ controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
151
+
152
+ controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
153
+ controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
154
+
155
+ return controlnet_res_stack, controlnet_single_res_stack
156
+
157
+
158
+ # @staticmethod
159
+ # def state_dict_converter():
160
+ # return FluxControlNetStateDictConverter()
161
+
162
+ def quantize(self):
163
+ def cast_to(weight, dtype=None, device=None, copy=False):
164
+ if device is None or weight.device == device:
165
+ if not copy:
166
+ if dtype is None or weight.dtype == dtype:
167
+ return weight
168
+ return weight.to(dtype=dtype, copy=copy)
169
+
170
+ r = torch.empty_like(weight, dtype=dtype, device=device)
171
+ r.copy_(weight)
172
+ return r
173
+
174
+ def cast_weight(s, input=None, dtype=None, device=None):
175
+ if input is not None:
176
+ if dtype is None:
177
+ dtype = input.dtype
178
+ if device is None:
179
+ device = input.device
180
+ weight = cast_to(s.weight, dtype, device)
181
+ return weight
182
+
183
+ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
184
+ if input is not None:
185
+ if dtype is None:
186
+ dtype = input.dtype
187
+ if bias_dtype is None:
188
+ bias_dtype = dtype
189
+ if device is None:
190
+ device = input.device
191
+ bias = None
192
+ weight = cast_to(s.weight, dtype, device)
193
+ bias = cast_to(s.bias, bias_dtype, device)
194
+ return weight, bias
195
+
196
+ class quantized_layer:
197
+ class QLinear(torch.nn.Linear):
198
+ def __init__(self, *args, **kwargs):
199
+ super().__init__(*args, **kwargs)
200
+
201
+ def forward(self,input,**kwargs):
202
+ weight,bias= cast_bias_weight(self,input)
203
+ return torch.nn.functional.linear(input,weight,bias)
204
+
205
+ class QRMSNorm(torch.nn.Module):
206
+ def __init__(self, module):
207
+ super().__init__()
208
+ self.module = module
209
+
210
+ def forward(self,hidden_states,**kwargs):
211
+ weight= cast_weight(self.module,hidden_states)
212
+ input_dtype = hidden_states.dtype
213
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
214
+ hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
215
+ hidden_states = hidden_states.to(input_dtype) * weight
216
+ return hidden_states
217
+
218
+ class QEmbedding(torch.nn.Embedding):
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+
222
+ def forward(self,input,**kwargs):
223
+ weight= cast_weight(self,input)
224
+ return torch.nn.functional.embedding(
225
+ input, weight, self.padding_idx, self.max_norm,
226
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
227
+
228
+ def replace_layer(model):
229
+ for name, module in model.named_children():
230
+ if isinstance(module,quantized_layer.QRMSNorm):
231
+ continue
232
+ if isinstance(module, torch.nn.Linear):
233
+ with init_weights_on_device():
234
+ new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
235
+ new_layer.weight = module.weight
236
+ if module.bias is not None:
237
+ new_layer.bias = module.bias
238
+ setattr(model, name, new_layer)
239
+ elif isinstance(module, RMSNorm):
240
+ if hasattr(module,"quantized"):
241
+ continue
242
+ module.quantized = True
243
+ new_layer = quantized_layer.QRMSNorm(module)
244
+ setattr(model, name, new_layer)
245
+ elif isinstance(module,torch.nn.Embedding):
246
+ rows, cols = module.weight.shape
247
+ new_layer = quantized_layer.QEmbedding(
248
+ num_embeddings=rows,
249
+ embedding_dim=cols,
250
+ _weight=module.weight,
251
+ # _freeze=module.freeze,
252
+ padding_idx=module.padding_idx,
253
+ max_norm=module.max_norm,
254
+ norm_type=module.norm_type,
255
+ scale_grad_by_freq=module.scale_grad_by_freq,
256
+ sparse=module.sparse)
257
+ setattr(model, name, new_layer)
258
+ else:
259
+ replace_layer(module)
260
+
261
+ replace_layer(self)
262
+
263
+
264
+
265
+class FluxControlNetStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        hash_value = hash_state_dict_keys(state_dict)
+        global_rename_dict = {
+            "context_embedder": "context_embedder",
+            "x_embedder": "x_embedder",
+            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
+            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
+            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
+            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
+            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
+            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
+            "norm_out.linear": "final_norm_out.linear",
+            "proj_out": "final_proj_out",
+        }
+        rename_dict = {
+            "proj_out": "proj_out",
+            "norm1.linear": "norm1_a.linear",
+            "norm1_context.linear": "norm1_b.linear",
+            "attn.to_q": "attn.a_to_q",
+            "attn.to_k": "attn.a_to_k",
+            "attn.to_v": "attn.a_to_v",
+            "attn.to_out.0": "attn.a_to_out",
+            "attn.add_q_proj": "attn.b_to_q",
+            "attn.add_k_proj": "attn.b_to_k",
+            "attn.add_v_proj": "attn.b_to_v",
+            "attn.to_add_out": "attn.b_to_out",
+            "ff.net.0.proj": "ff_a.0",
+            "ff.net.2": "ff_a.2",
+            "ff_context.net.0.proj": "ff_b.0",
+            "ff_context.net.2": "ff_b.2",
+            "attn.norm_q": "attn.norm_q_a",
+            "attn.norm_k": "attn.norm_k_a",
+            "attn.norm_added_q": "attn.norm_q_b",
+            "attn.norm_added_k": "attn.norm_k_b",
+        }
+        rename_dict_single = {
+            "attn.to_q": "a_to_q",
+            "attn.to_k": "a_to_k",
+            "attn.to_v": "a_to_v",
+            "attn.norm_q": "norm_q_a",
+            "attn.norm_k": "norm_k_a",
+            "norm.linear": "norm.linear",
+            "proj_mlp": "proj_in_besides_attn",
+            "proj_out": "proj_out",
+        }
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            if name.endswith(".weight") or name.endswith(".bias"):
+                suffix = ".weight" if name.endswith(".weight") else ".bias"
+                prefix = name[:-len(suffix)]
+                if prefix in global_rename_dict:
+                    state_dict_[global_rename_dict[prefix] + suffix] = param
+                elif prefix.startswith("transformer_blocks."):
+                    names = prefix.split(".")
+                    names[0] = "blocks"
+                    middle = ".".join(names[2:])
+                    if middle in rename_dict:
+                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+                        state_dict_[name_] = param
+                elif prefix.startswith("single_transformer_blocks."):
+                    names = prefix.split(".")
+                    names[0] = "single_blocks"
+                    middle = ".".join(names[2:])
+                    if middle in rename_dict_single:
+                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
+                        state_dict_[name_] = param
+                    else:
+                        state_dict_[name] = param
+                else:
+                    state_dict_[name] = param
+            else:
+                state_dict_[name] = param
+        for name in list(state_dict_.keys()):
+            if ".proj_in_besides_attn." in name:
+                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
+                param = torch.concat([
+                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_q.")],
+                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_k.")],
+                    state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_v.")],
+                    state_dict_[name],
+                ], dim=0)
+                state_dict_[name_] = param
+                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_q."))
+                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_k."))
+                state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_v."))
+                state_dict_.pop(name)
+        for name in list(state_dict_.keys()):
+            for component in ["a", "b"]:
+                if f".{component}_to_q." in name:
+                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+                    param = torch.concat([
+                        state_dict_[name],
+                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+                    ], dim=0)
+                    state_dict_[name_] = param
+                    state_dict_.pop(name)
+                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
+        if hash_value == "78d18b9101345ff695f312e7e62538c0":
+            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
+        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
+            extra_kwargs = {"num_single_blocks": 0}
+        elif hash_value == "52357cb26250681367488a8954c271e8":
+            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
+        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
+            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
+        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
+            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
+        elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
+            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
+        else:
+            extra_kwargs = {}
+        return state_dict_, extra_kwargs
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
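For reference, a minimal sketch of how `from_diffusers` rewrites one block-level key; the key and the single-entry rename table below are illustrative, not taken from a real checkpoint:

rename_dict = {"attn.to_q": "attn.a_to_q"}          # one entry from the table above
name = "transformer_blocks.3.attn.to_q.weight"      # illustrative diffusers-style key
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]                        # "transformer_blocks.3.attn.to_q"
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])                        # "attn.to_q"
new_name = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
assert new_name == "blocks.3.attn.a_to_q.weight"
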
diffsynth/models/flux_dit.py ADDED
@@ -0,0 +1,398 @@
+import torch
+from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
+from einops import rearrange
+
+
+def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
+    batch_size, num_tokens = hidden_states.shape[0:2]
+    ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
+    hidden_states = hidden_states + scale * ip_hidden_states
+    return hidden_states
+
+
+class RoPEEmbedding(torch.nn.Module):
+    def __init__(self, dim, theta, axes_dim):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+        assert dim % 2 == 0, "The dimension must be even."
+
+        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+        omega = 1.0 / (theta**scale)
+
+        batch_size, seq_length = pos.shape
+        out = torch.einsum("...n,d->...nd", pos, omega)
+        cos_out = torch.cos(out)
+        sin_out = torch.sin(out)
+
+        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+        return out.float()
+
+    def forward(self, ids):
+        n_axes = ids.shape[-1]
+        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+        return emb.unsqueeze(1)
+
+
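A small standalone check of the `rope`/`apply_rope` math above: each even/odd channel pair is rotated by the angle `pos * omega_i`. The sizes below are illustrative:

import torch

dim, theta = 4, 10000
pos = torch.tensor([[3.0]])                              # (batch=1, seq=1), position 3
scale = torch.arange(0, dim, 2, dtype=torch.float64) / dim
omega = 1.0 / (theta ** scale)                           # per-pair frequencies
out = torch.einsum("...n,d->...nd", pos, omega)
freqs = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
freqs = freqs.view(1, -1, dim // 2, 2, 2).float()        # (B, S, dim/2, 2, 2)

x = torch.randn(1, 1, dim)
x_ = x.reshape(1, 1, -1, 1, 2)                           # same indexing as apply_rope
rotated = freqs[..., 0] * x_[..., 0] + freqs[..., 1] * x_[..., 1]
angle = (pos * omega[0]).float()
expected = x[0, 0, 0] * torch.cos(angle) - x[0, 0, 1] * torch.sin(angle)
assert torch.allclose(rotated[0, 0, 0, 0], expected.squeeze(), atol=1e-5)
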
+class FluxJointAttention(torch.nn.Module):
+    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.only_out_a = only_out_a
+
+        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
+
+        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+        self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
+        self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
+
+        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
+        if not only_out_a:
+            self.b_to_out = torch.nn.Linear(dim_b, dim_b)
+
+    def apply_rope(self, xq, xk, freqs_cis):
+        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+        batch_size = hidden_states_a.shape[0]
+
+        # Part A
+        qkv_a = self.a_to_qkv(hidden_states_a)
+        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
+        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
+
+        # Part B
+        qkv_b = self.b_to_qkv(hidden_states_b)
+        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
+        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
+
+        q = torch.concat([q_b, q_a], dim=2)
+        k = torch.concat([k_b, k_a], dim=2)
+        v = torch.concat([v_b, v_a], dim=2)
+
+        q, k = self.apply_rope(q, k, image_rotary_emb)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
+        if ipadapter_kwargs_list is not None:
+            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
+        hidden_states_a = self.a_to_out(hidden_states_a)
+        if self.only_out_a:
+            return hidden_states_a
+        else:
+            hidden_states_b = self.b_to_out(hidden_states_b)
+            return hidden_states_a, hidden_states_b
+
+
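Note the concatenation order in `FluxJointAttention`: the text (B) stream goes first along the sequence axis and the attention output is split back in the same order. A shape-only sketch (token counts are illustrative):

import torch

q_text = torch.randn(1, 24, 77, 128)     # (B, heads, text_len, head_dim)
q_img = torch.randn(1, 24, 1024, 128)    # 1024 image tokens
q = torch.concat([q_text, q_img], dim=2)
text_out, img_out = q[:, :, :77], q[:, :, 77:]   # same split applied to the attention output
assert text_out.shape[2] == 77 and img_out.shape[2] == 1024
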
+class FluxJointTransformerBlock(torch.nn.Module):
+    def __init__(self, dim, num_attention_heads):
+        super().__init__()
+        self.norm1_a = AdaLayerNorm(dim)
+        self.norm1_b = AdaLayerNorm(dim)
+
+        self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
+
+        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_a = torch.nn.Sequential(
+            torch.nn.Linear(dim, dim*4),
+            torch.nn.GELU(approximate="tanh"),
+            torch.nn.Linear(dim*4, dim)
+        )
+
+        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_b = torch.nn.Sequential(
+            torch.nn.Linear(dim, dim*4),
+            torch.nn.GELU(approximate="tanh"),
+            torch.nn.Linear(dim*4, dim)
+        )
+
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+        # Attention
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
+
+        # Part A
+        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+        # Part B
+        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+        return hidden_states_a, hidden_states_b
+
+
+class FluxSingleAttention(torch.nn.Module):
+    def __init__(self, dim_a, dim_b, num_heads, head_dim):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+
+        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+
+    def apply_rope(self, xq, xk, freqs_cis):
+        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+    def forward(self, hidden_states, image_rotary_emb):
+        batch_size = hidden_states.shape[0]
+
+        qkv_a = self.a_to_qkv(hidden_states)
+        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q_a, k_a, v = qkv_a.chunk(3, dim=1)
+        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
+
+        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+        return hidden_states
+
+
+class AdaLayerNormSingle(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.silu = torch.nn.SiLU()
+        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
+        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb):
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
+class FluxSingleTransformerBlock(torch.nn.Module):
+    def __init__(self, dim, num_attention_heads):
+        super().__init__()
+        self.num_heads = num_attention_heads
+        self.head_dim = dim // num_attention_heads
+        self.dim = dim
+
+        self.norm = AdaLayerNormSingle(dim)
+        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
+        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
+        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
+
+        self.proj_out = torch.nn.Linear(dim * 5, dim)
+
+    def apply_rope(self, xq, xk, freqs_cis):
+        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+        batch_size = hidden_states.shape[0]
+
+        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q, k, v = qkv.chunk(3, dim=1)
+        q, k = self.norm_q_a(q), self.norm_k_a(k)
+
+        q, k = self.apply_rope(q, k, image_rotary_emb)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+        if ipadapter_kwargs_list is not None:
+            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
+        return hidden_states
+
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+        residual = hidden_states_a
+        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
+        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
+        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
+
+        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
+        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
+
+        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
+        hidden_states_a = residual + hidden_states_a
+
+        return hidden_states_a, hidden_states_b
+
+
+class AdaLayerNormContinuous(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.silu = torch.nn.SiLU()
+        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
+        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
+
+    def forward(self, x, conditioning):
+        emb = self.linear(self.silu(conditioning))
+        shift, scale = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
+        return x
+
+
278
+
279
+ _repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
280
+
281
+ def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
282
+ super().__init__()
283
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
284
+ self.time_embedder = TimestepEmbeddings(256, 3072)
285
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
286
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
287
+ self.context_embedder = torch.nn.Linear(4096, 3072)
288
+ self.x_embedder = torch.nn.Linear(input_dim, 3072)
289
+
290
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
291
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
292
+
293
+ self.final_norm_out = AdaLayerNormContinuous(3072)
294
+ self.final_proj_out = torch.nn.Linear(3072, 64)
295
+
296
+ self.input_dim = input_dim
297
+
298
+
299
+ def patchify(self, hidden_states):
300
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
301
+ return hidden_states
302
+
303
+
304
+ def unpatchify(self, hidden_states, height, width):
305
+ hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
306
+ return hidden_states
307
+
308
+
309
+ def prepare_image_ids(self, latents):
310
+ batch_size, _, height, width = latents.shape
311
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
312
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
313
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
314
+
315
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
316
+
317
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
318
+ latent_image_ids = latent_image_ids.reshape(
319
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
320
+ )
321
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
322
+
323
+ return latent_image_ids
324
+
325
+
326
+ def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
327
+ N = len(entity_masks)
328
+ batch_size = entity_masks[0].shape[0]
329
+ total_seq_len = N * prompt_seq_len + image_seq_len
330
+ patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
331
+ attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
332
+
333
+ image_start = N * prompt_seq_len
334
+ image_end = N * prompt_seq_len + image_seq_len
335
+ # prompt-image mask
336
+ for i in range(N):
337
+ prompt_start = i * prompt_seq_len
338
+ prompt_end = (i + 1) * prompt_seq_len
339
+ image_mask = torch.sum(patched_masks[i], dim=-1) > 0
340
+ image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
341
+ # prompt update with image
342
+ attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
343
+ # image update with prompt
344
+ attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
345
+ # prompt-prompt mask
346
+ for i in range(N):
347
+ for j in range(N):
348
+ if i != j:
349
+ prompt_start_i = i * prompt_seq_len
350
+ prompt_end_i = (i + 1) * prompt_seq_len
351
+ prompt_start_j = j * prompt_seq_len
352
+ prompt_end_j = (j + 1) * prompt_seq_len
353
+ attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
354
+
355
+ attention_mask = attention_mask.float()
356
+ attention_mask[attention_mask == 0] = float('-inf')
357
+ attention_mask[attention_mask == 1] = 0
358
+ return attention_mask
359
+
360
+
361
+ def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
362
+ max_masks = 0
363
+ attention_mask = None
364
+ prompt_embs = [prompt_emb]
365
+ if entity_masks is not None:
366
+ # entity_masks
367
+ batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
368
+ entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
369
+ entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
370
+ # global mask
371
+ global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
372
+ entity_masks = entity_masks + [global_mask] # append global to last
373
+ # attention mask
374
+ attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
375
+ attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
376
+ attention_mask = attention_mask.unsqueeze(1)
377
+ # embds: n_masks * b * seq * d
378
+ local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
379
+ prompt_embs = local_embs + prompt_embs # append global to last
380
+ prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
381
+ prompt_emb = torch.cat(prompt_embs, dim=1)
382
+
383
+ # positional embedding
384
+ text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
385
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
386
+ return prompt_emb, image_rotary_emb, attention_mask
387
+
388
+
389
+ def forward(
390
+ self,
391
+ hidden_states,
392
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
393
+ tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
394
+ use_gradient_checkpointing=False,
395
+ **kwargs
396
+ ):
397
+ # (Deprecated) The real forward is in `pipelines.flux_image`.
398
+ return None
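A quick round-trip check of the 2×2 `patchify`/`unpatchify` pair used by `FluxDiT` (latent sizes are illustrative):

import torch
from einops import rearrange

x = torch.randn(1, 16, 8, 8)                     # (B, C, H, W) latent
patched = rearrange(x, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
assert patched.shape == (1, 16, 64)              # (8//2)*(8//2) tokens, 16*2*2 channels
restored = rearrange(patched, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=4, W=4)
assert torch.equal(x, restored)
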
diffsynth/models/flux_infiniteyou.py ADDED
@@ -0,0 +1,129 @@
+import math
+import torch
+import torch.nn as nn
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # make contiguous; shape stays (bs, n_heads, length, dim_per_head)
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttention(nn.Module):
+
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
+
+class InfiniteYouImageProjector(nn.Module):
+
+    def __init__(
+        self,
+        dim=1280,
+        depth=4,
+        dim_head=64,
+        heads=20,
+        num_queries=8,
+        embedding_dim=512,
+        output_dim=4096,
+        ff_mult=4,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+        self.proj_in = nn.Linear(embedding_dim, dim)
+
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList([
+                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                    FeedForward(dim=dim, mult=ff_mult),
+                ]))
+
+    def forward(self, x):
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+        latents = latents.to(dtype=x.dtype, device=x.device)
+
+        x = self.proj_in(x)
+
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
+
+    @staticmethod
+    def state_dict_converter():
+        return FluxInfiniteYouImageProjectorStateDictConverter()
+
+
+class FluxInfiniteYouImageProjectorStateDictConverter:
+
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict['image_proj']
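A minimal usage sketch of the projector, assuming the default hyperparameters; the input token count is illustrative:

import torch

proj = InfiniteYouImageProjector()      # defaults: 8 queries, 512-d input, 4096-d output
face_embeds = torch.randn(2, 1, 512)    # (batch, n_tokens, embedding_dim)
tokens = proj(face_embeds)
assert tokens.shape == (2, 8, 4096)
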
diffsynth/models/flux_ipadapter.py ADDED
@@ -0,0 +1,110 @@
+from .general_modules import RMSNorm
+from transformers import SiglipVisionModel, SiglipVisionConfig
+import torch
+
+
+class SiglipVisionModelSO400M(SiglipVisionModel):
+    def __init__(self):
+        config = SiglipVisionConfig(
+            hidden_size=1152,
+            image_size=384,
+            intermediate_size=4304,
+            model_type="siglip_vision_model",
+            num_attention_heads=16,
+            num_hidden_layers=27,
+            patch_size=14,
+            architectures=["SiglipModel"],
+            initializer_factor=1.0,
+            torch_dtype="float32",
+            transformers_version="4.37.0.dev0"
+        )
+        super().__init__(config)
+
+
+class MLPProjModel(torch.nn.Module):
+    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+            torch.nn.GELU(),
+            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+        )
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, id_embeds):
+        x = self.proj(id_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        x = self.norm(x)
+        return x
+
+
+class IpAdapterModule(torch.nn.Module):
+    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
+        super().__init__()
+        self.num_heads = num_attention_heads
+        self.head_dim = attention_head_dim
+        output_dim = num_attention_heads * attention_head_dim
+        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
+
+    def forward(self, hidden_states):
+        batch_size = hidden_states.shape[0]
+        # ip_k
+        ip_k = self.to_k_ip(hidden_states)
+        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_k = self.norm_added_k(ip_k)
+        # ip_v
+        ip_v = self.to_v_ip(hidden_states)
+        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        return ip_k, ip_v
+
+
+class FluxIpAdapter(torch.nn.Module):
+    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
+        super().__init__()
+        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
+        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
+        self.set_adapter()
+
+    def set_adapter(self):
+        self.call_block_id = {i: i for i in range(len(self.ipadapter_modules))}
+
+    def forward(self, hidden_states, scale=1.0):
+        hidden_states = self.image_proj(hidden_states)
+        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+        ip_kv_dict = {}
+        for block_id in self.call_block_id:
+            ipadapter_id = self.call_block_id[block_id]
+            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+            ip_kv_dict[block_id] = {
+                "ip_k": ip_k,
+                "ip_v": ip_v,
+                "scale": scale
+            }
+        return ip_kv_dict
+
+    @staticmethod
+    def state_dict_converter():
+        return FluxIpAdapterStateDictConverter()
+
+
+class FluxIpAdapterStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        state_dict_ = {}
+        for name in state_dict["ip_adapter"]:
+            name_ = 'ipadapter_modules.' + name
+            state_dict_[name_] = state_dict["ip_adapter"][name]
+        for name in state_dict["image_proj"]:
+            name_ = "image_proj." + name
+            state_dict_[name_] = state_dict["image_proj"][name]
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
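A minimal usage sketch, assuming the default configuration (57 modules matching the DiT's 19 joint + 38 single blocks); the pooled SigLIP feature below is illustrative:

import torch

adapter = FluxIpAdapter()
image_embeds = torch.randn(1, 1152)     # pooled SigLIP embedding
ip_kv = adapter(image_embeds, scale=0.7)
assert len(ip_kv) == 57
assert ip_kv[0]["ip_k"].shape == (1, 24, 128, 128)   # (B, heads, num_tokens, head_dim)
assert ip_kv[0]["scale"] == 0.7
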
diffsynth/models/flux_lora_encoder.py ADDED
@@ -0,0 +1,521 @@
+import torch
+from einops import rearrange
+
+
+def low_version_attention(query, key, value, attn_bias=None):
+    scale = 1 / query.shape[-1] ** 0.5
+    query = query * scale
+    attn = torch.matmul(query, key.transpose(-2, -1))
+    if attn_bias is not None:
+        attn = attn + attn_bias
+    attn = attn.softmax(-1)
+    return attn @ value
+
+
+class Attention(torch.nn.Module):
+
+    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+        super().__init__()
+        dim_inner = head_dim * num_heads
+        kv_dim = kv_dim if kv_dim is not None else q_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
+        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+
+    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
+        batch_size = q.shape[0]
+        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+        hidden_states = hidden_states + scale * ip_hidden_states
+        return hidden_states
+
+    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        q = self.to_q(hidden_states)
+        k = self.to_k(encoder_hidden_states)
+        v = self.to_v(encoder_hidden_states)
+
+        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if qkv_preprocessor is not None:
+            q, k, v = qkv_preprocessor(q, k, v)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        if ipadapter_kwargs is not None:
+            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+
+        hidden_states = self.to_out(hidden_states)
+
+        return hidden_states
+
+    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        q = self.to_q(hidden_states)
+        k = self.to_k(encoder_hidden_states)
+        v = self.to_v(encoder_hidden_states)
+
+        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
+        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
+        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
+
+        if attn_mask is not None:
+            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
+        else:
+            import xformers.ops as xops
+            hidden_states = xops.memory_efficient_attention(q, k, v)
+        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
+
+        hidden_states = hidden_states.to(q.dtype)
+        hidden_states = self.to_out(hidden_states)
+
+        return hidden_states
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
+
+
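A quick standalone check that the manual fallback above matches PyTorch's fused SDPA under the default 1/sqrt(d) scaling (tensor sizes are illustrative):

import torch

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
manual = low_version_attention(q, k, v)
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, fused, atol=1e-5)
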
+class CLIPEncoderLayer(torch.nn.Module):
+    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
+        super().__init__()
+        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
+        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
+        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
+        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
+        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
+
+        self.use_quick_gelu = use_quick_gelu
+
+    def quickGELU(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward(self, hidden_states, attn_mask=None):
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.fc1(hidden_states)
+        if self.use_quick_gelu:
+            hidden_states = self.quickGELU(hidden_states)
+        else:
+            hidden_states = torch.nn.functional.gelu(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class SDTextEncoder(torch.nn.Module):
+    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
+        super().__init__()
+
+        # token_embedding
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
+
+        # position_embeds (This is a fixed tensor)
+        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
+
+        # encoders
+        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
+
+        # attn_mask
+        self.attn_mask = self.attention_mask(max_position_embeddings)
+
+        # final_layer_norm
+        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
+
+    def attention_mask(self, length):
+        mask = torch.empty(length, length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def forward(self, input_ids, clip_skip=1):
+        embeds = self.token_embedding(input_ids) + self.position_embeds
+        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        for encoder_id, encoder in enumerate(self.encoders):
+            embeds = encoder(embeds, attn_mask=attn_mask)
+            if encoder_id + clip_skip == len(self.encoders):
+                break
+        embeds = self.final_layer_norm(embeds)
+        return embeds
+
+    @staticmethod
+    def state_dict_converter():
+        return SDTextEncoderStateDictConverter()
+
+
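The `attention_mask` above is strictly causal: `-inf` above the diagonal and zero elsewhere, so token i attends only to positions at or before i, while `clip_skip` stops encoding `clip_skip - 1` layers before the end. A small sketch of the mask construction:

import torch

mask = torch.empty(4, 4)
mask.fill_(float("-inf"))
mask.triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
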
+class SDTextEncoderStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        rename_dict = {
+            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+            "text_model.embeddings.position_embedding.weight": "position_embeds",
+            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
+        }
+        attn_rename_dict = {
+            "self_attn.q_proj": "attn.to_q",
+            "self_attn.k_proj": "attn.to_k",
+            "self_attn.v_proj": "attn.to_v",
+            "self_attn.out_proj": "attn.to_out",
+            "layer_norm1": "layer_norm1",
+            "layer_norm2": "layer_norm2",
+            "mlp.fc1": "fc1",
+            "mlp.fc2": "fc2",
+        }
+        state_dict_ = {}
+        for name in state_dict:
+            if name in rename_dict:
+                param = state_dict[name]
+                if name == "text_model.embeddings.position_embedding.weight":
+                    param = param.reshape((1, param.shape[0], param.shape[1]))
+                state_dict_[rename_dict[name]] = param
+            elif name.startswith("text_model.encoder.layers."):
+                param = state_dict[name]
+                names = name.split(".")
+                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                state_dict_[name_] = param
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        rename_dict = {
+            "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
+            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
+            "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
388
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
389
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
390
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
391
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
392
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
393
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
394
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
395
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
396
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
397
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
398
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
399
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
400
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
401
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
402
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
403
+ }
404
+ state_dict_ = {}
405
+ for name in state_dict:
406
+ if name in rename_dict:
407
+ param = state_dict[name]
408
+ if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
409
+ param = param.reshape((1, param.shape[0], param.shape[1]))
410
+ state_dict_[rename_dict[name]] = param
411
+ return state_dict_
412
+
413
+
414
+
415
+ class LoRALayerBlock(torch.nn.Module):
416
+ def __init__(self, L, dim_in, dim_out):
417
+ super().__init__()
418
+ self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
419
+ self.layer_norm = torch.nn.LayerNorm(dim_out)
420
+
421
+ def forward(self, lora_A, lora_B):
422
+ x = self.x @ lora_A.T @ lora_B.T
423
+ x = self.layer_norm(x)
424
+ return x
425
+
426
+
427
+ class LoRAEmbedder(torch.nn.Module):
428
+ def __init__(self, lora_patterns=None, L=1, out_dim=2048):
429
+ super().__init__()
430
+ if lora_patterns is None:
431
+ lora_patterns = self.default_lora_patterns()
432
+
433
+ model_dict = {}
434
+ for lora_pattern in lora_patterns:
435
+ name, dim = lora_pattern["name"], lora_pattern["dim"]
436
+ model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
437
+ self.model_dict = torch.nn.ModuleDict(model_dict)
438
+
439
+ proj_dict = {}
440
+ for lora_pattern in lora_patterns:
441
+ layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
442
+ if layer_type not in proj_dict:
443
+ proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
444
+ self.proj_dict = torch.nn.ModuleDict(proj_dict)
445
+
446
+ self.lora_patterns = lora_patterns
447
+
448
+
449
+ def default_lora_patterns(self):
450
+ lora_patterns = []
451
+ lora_dict = {
452
+ "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
453
+ "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
454
+ }
455
+ for i in range(19):
456
+ for suffix in lora_dict:
457
+ lora_patterns.append({
458
+ "name": f"blocks.{i}.{suffix}",
459
+ "dim": lora_dict[suffix],
460
+ "type": suffix,
461
+ })
462
+ lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
463
+ for i in range(38):
464
+ for suffix in lora_dict:
465
+ lora_patterns.append({
466
+ "name": f"single_blocks.{i}.{suffix}",
467
+ "dim": lora_dict[suffix],
468
+ "type": suffix,
469
+ })
470
+ return lora_patterns
471
+
472
+ def forward(self, lora):
473
+ lora_emb = []
474
+ for lora_pattern in self.lora_patterns:
475
+ name, layer_type = lora_pattern["name"], lora_pattern["type"]
476
+ lora_A = lora[name + ".lora_A.weight"]
477
+ lora_B = lora[name + ".lora_B.weight"]
478
+ lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
479
+ lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
480
+ lora_emb.append(lora_out)
481
+ lora_emb = torch.concat(lora_emb, dim=1)
482
+ return lora_emb
483
+
484
+
485
+ class FluxLoRAEncoder(torch.nn.Module):
486
+ def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
487
+ super().__init__()
488
+ self.num_embeds_per_lora = num_embeds_per_lora
489
+ # embedder
490
+ self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
491
+
492
+ # encoders
493
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
494
+
495
+ # special embedding
496
+ self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
497
+ self.num_special_embeds = num_special_embeds
498
+
499
+ # final layer
500
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
501
+ self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
502
+
503
+ def forward(self, lora):
504
+ lora_embeds = self.embedder(lora)
505
+ special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
506
+ embeds = torch.concat([special_embeds, lora_embeds], dim=1)
507
+ for encoder_id, encoder in enumerate(self.encoders):
508
+ embeds = encoder(embeds)
509
+ embeds = embeds[:, :self.num_special_embeds]
510
+ embeds = self.final_layer_norm(embeds)
511
+ embeds = self.final_linear(embeds)
512
+ return embeds
513
+
514
+ @staticmethod
515
+ def state_dict_converter():
516
+ return FluxLoRAEncoderStateDictConverter()
517
+
518
+
519
+ class FluxLoRAEncoderStateDictConverter:
520
+ def from_civitai(self, state_dict):
521
+ return state_dict
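
A hypothetical end-to-end sketch of the encoder above (not part of this commit): it builds a random rank-16 LoRA matching the default FLUX patterns and condenses it into the special conditioning tokens. All weights are random, and the full default pattern set (304 layers) makes this compute-heavy on CPU.

```python
import torch

# Sketch only: random weights, rank-16 LoRA, default patterns (304 layers).
encoder = FluxLoRAEncoder()
rank = 16
lora = {}
for pattern in encoder.embedder.lora_patterns:
    dim_in, dim_out = pattern["dim"]
    lora[pattern["name"] + ".lora_A.weight"] = torch.randn(rank, dim_in)
    lora[pattern["name"] + ".lora_B.weight"] = torch.zeros(dim_out, rank)
with torch.no_grad():
    embeds = encoder(lora)
print(embeds.shape)  # torch.Size([1, 1, 4096]) -- one special token, embed_dim 4096
```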
diffsynth/models/flux_lora_patcher.py ADDED
@@ -0,0 +1,306 @@
1
+ import torch, math
2
+ from ..core.loader import load_state_dict
3
+ from typing import Union
4
+
5
+ class GeneralLoRALoader:
6
+ def __init__(self, device="cpu", torch_dtype=torch.float32):
7
+ self.device = device
8
+ self.torch_dtype = torch_dtype
9
+
10
+
11
+ def get_name_dict(self, lora_state_dict):
12
+ lora_name_dict = {}
13
+ for key in lora_state_dict:
14
+ if ".lora_B." not in key:
15
+ continue
16
+ keys = key.split(".")
17
+ if len(keys) > keys.index("lora_B") + 2:
18
+ keys.pop(keys.index("lora_B") + 1)
19
+ keys.pop(keys.index("lora_B"))
20
+ if keys[0] == "diffusion_model":
21
+ keys.pop(0)
22
+ keys.pop(-1)
23
+ target_name = ".".join(keys)
24
+ lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
25
+ return lora_name_dict
26
+
27
+
28
+ def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
29
+ updated_num = 0
30
+ lora_name_dict = self.get_name_dict(state_dict_lora)
31
+ for name, module in model.named_modules():
32
+ if name in lora_name_dict:
33
+ weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
34
+ weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
35
+ if len(weight_up.shape) == 4:
36
+ weight_up = weight_up.squeeze(3).squeeze(2)
37
+ weight_down = weight_down.squeeze(3).squeeze(2)
38
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
39
+ else:
40
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
41
+ state_dict = module.state_dict()
42
+ state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
43
+ module.load_state_dict(state_dict)
44
+ updated_num += 1
45
+ print(f"{updated_num} tensors are updated by LoRA.")
46
+
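
A minimal numeric sketch of the per-layer update that `load` performs: the fused weight is W' = W + alpha * (lora_B @ lora_A). The shapes here are illustrative, not tied to any particular FLUX layer.

```python
import torch

W = torch.randn(64, 32)      # base weight (out_features, in_features)
lora_A = torch.randn(4, 32)  # down-projection, rank 4
lora_B = torch.randn(64, 4)  # up-projection
alpha = 1.0

W_fused = W + alpha * (lora_B @ lora_A)  # what load() writes back into the module
assert W_fused.shape == W.shape
```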
47
+ class FluxLoRALoader(GeneralLoRALoader):
48
+ def __init__(self, device="cpu", torch_dtype=torch.float32):
49
+ super().__init__(device=device, torch_dtype=torch_dtype)
50
+
51
+ self.diffusers_rename_dict = {
52
+ "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
53
+ "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
54
+ "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
55
+ "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
56
+ "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
57
+ "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
58
+ "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
59
+ "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
60
+ "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
61
+ "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
62
+ "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
63
+ "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
64
+ "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
65
+ "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
66
+ "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
67
+ "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
68
+ "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
69
+ "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
70
+ "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
71
+ "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
72
+ "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
73
+ "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
74
+ "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
75
+ "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
76
+ "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
77
+ "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
78
+ "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
79
+ "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
80
+ "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
81
+ "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
82
+ "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
83
+ "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
84
+ "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
85
+ "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
86
+ "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
87
+ "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
88
+ "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
89
+ "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
90
+ "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
91
+ "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
92
+ }
93
+
94
+ self.civitai_rename_dict = {
95
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
96
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
97
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
98
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
99
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
100
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
101
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
102
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
103
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
104
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
105
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
106
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
107
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
108
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
109
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
110
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
111
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
112
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
113
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
114
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
115
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
116
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
117
+ "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
118
+ "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
119
+ "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
120
+ "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
121
+ }
122
+
123
+ def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
124
+ super().load(model, state_dict_lora, alpha)
125
+
126
+
127
+ def convert_state_dict(self, state_dict):
+
+ def guess_block_id(name, model_resource):
130
+ if model_resource == 'civitai':
131
+ names = name.split("_")
132
+ for i in names:
133
+ if i.isdigit():
134
+ return i, name.replace(f"_{i}_", "_blockid_")
135
+ if model_resource == 'diffusers':
136
+ names = name.split(".")
137
+ for i in names:
138
+ if i.isdigit():
139
+ return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
140
+ return None, None
141
+
142
+ def guess_resource(state_dict):
143
+ for k in state_dict:
144
+ if "lora_unet_" in k:
145
+ return 'civitai'
146
+ elif k.startswith("transformer."):
147
+ return 'diffusers'
148
+ return None  # fell through the loop: no key matched either naming scheme
150
+
151
+ model_resource = guess_resource(state_dict)
152
+ if model_resource is None:
153
+ return state_dict
154
+
155
+ rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
156
+ def guess_alpha(state_dict):
157
+ for name, param in state_dict.items():
158
+ if ".alpha" in name:
159
+ for suffix in [".lora_down.weight", ".lora_A.weight"]:
160
+ name_ = name.replace(".alpha", suffix)
161
+ if name_ in state_dict:
162
+ lora_alpha = param.item() / state_dict[name_].shape[0]
163
+ lora_alpha = math.sqrt(lora_alpha)
164
+ return lora_alpha
165
+
166
+ return 1
167
+
168
+ alpha = guess_alpha(state_dict)
169
+
170
+ state_dict_ = {}
171
+ for name, param in state_dict.items():
172
+ block_id, source_name = guess_block_id(name, model_resource)
173
+ if alpha != 1:
174
+ param *= alpha
175
+ if source_name in rename_dict:
176
+ target_name = rename_dict[source_name]
177
+ target_name = target_name.replace(".blockid.", f".{block_id}.")
178
+ state_dict_[target_name] = param
179
+ else:
180
+ state_dict_[name] = param
181
+
182
+ if model_resource == 'diffusers':
183
+ for name in list(state_dict_.keys()):
184
+ if "single_blocks." in name and ".a_to_q." in name:
185
+ mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
186
+ if mlp is None:
187
+ dim = 4
188
+ if 'lora_A' in name:
189
+ dim = 1
190
+ mlp = torch.zeros(dim * state_dict_[name].shape[0],
191
+ *state_dict_[name].shape[1:],
192
+ dtype=state_dict_[name].dtype)
193
+ else:
194
+ state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
195
+ if 'lora_A' in name:
196
+ param = torch.concat([
197
+ state_dict_.pop(name),
198
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
199
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
200
+ mlp,
201
+ ], dim=0)
202
+ elif 'lora_B' in name:
203
+ d, r = state_dict_[name].shape
204
+ param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
205
+ param[:d, :r] = state_dict_.pop(name)
206
+ param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
207
+ param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
208
+ param[3*d:, 3*r:] = mlp
209
+ else:
210
+ param = torch.concat([
211
+ state_dict_.pop(name),
212
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
213
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
214
+ mlp,
215
+ ], dim=0)
216
+ name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
217
+ state_dict_[name_] = param
218
+ for name in list(state_dict_.keys()):
219
+ for component in ["a", "b"]:
220
+ if f".{component}_to_q." in name:
221
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
222
+ if 'lora_A' in name:
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ elif 'lora_B' in name:
+ origin = state_dict_[name]
+ d, r = origin.shape
+ # Place the q/k/v up-projections block-diagonally so that each output
+ # block only reads its own rank slice of the stacked lora_A factors.
+ param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
+ param[:d, :r] = origin
+ param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
+ param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
+ else:
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ state_dict_[name_] = param
+ state_dict_.pop(name)
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
247
+ return state_dict_
248
+
249
+
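
The q/k/v fusion above relies on a block-diagonal placement of the lora_B factors: lora_A factors are stacked along the rank axis, and each output block of the fused lora_B must only read its own rank slice, so that the fused product equals the concatenation of the per-branch products. A standalone check with hypothetical toy shapes:

```python
import torch

d, r, h = 8, 4, 8  # toy output dim, rank, hidden size
A = {k: torch.randn(r, h) for k in "qkv"}
B = {k: torch.randn(d, r) for k in "qkv"}

A_fused = torch.cat([A["q"], A["k"], A["v"]], dim=0)  # (3r, h), stacked ranks
B_fused = torch.zeros(3 * d, 3 * r)                   # block-diagonal
for i, k in enumerate("qkv"):
    B_fused[i * d:(i + 1) * d, i * r:(i + 1) * r] = B[k]

expected = torch.cat([B[k] @ A[k] for k in "qkv"], dim=0)  # (3d, h)
assert torch.allclose(B_fused @ A_fused, expected, atol=1e-5)
```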
250
+ class LoraMerger(torch.nn.Module):
251
+ def __init__(self, dim):
252
+ super().__init__()
253
+ self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
254
+ self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
255
+ self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
256
+ self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
257
+ self.bias = torch.nn.Parameter(torch.randn((dim,)))
258
+ self.activation = torch.nn.Sigmoid()
259
+ self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
260
+ self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
261
+
262
+ def forward(self, base_output, lora_outputs):
263
+ norm_base_output = self.norm_base(base_output)
264
+ norm_lora_outputs = self.norm_lora(lora_outputs)
265
+ gate = self.activation(
266
+ norm_base_output * self.weight_base \
267
+ + norm_lora_outputs * self.weight_lora \
268
+ + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
269
+ )
270
+ output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
271
+ return output
272
+
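
A minimal sketch of the gating in `LoraMerger.forward` (toy width 4): the sigmoid gate mixes the normalized base activation, the normalized LoRA activations, and their elementwise interaction, and the gated LoRA branches are summed back onto the base.

```python
import torch

merger = LoraMerger(dim=4)
base = torch.randn(1, 3, 4)    # base layer output
loras = torch.randn(2, 3, 4)   # two LoRA branch outputs, stacked on dim 0
out = merger(base, loras)      # base + sum(weight_out * gate * lora)
assert out.shape == base.shape
```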
273
+ class FluxLoraPatcher(torch.nn.Module):
274
+ def __init__(self, lora_patterns=None):
275
+ super().__init__()
276
+ if lora_patterns is None:
277
+ lora_patterns = self.default_lora_patterns()
278
+ model_dict = {}
279
+ for lora_pattern in lora_patterns:
280
+ name, dim = lora_pattern["name"], lora_pattern["dim"]
281
+ model_dict[name.replace(".", "___")] = LoraMerger(dim)
282
+ self.model_dict = torch.nn.ModuleDict(model_dict)
283
+
284
+ def default_lora_patterns(self):
285
+ lora_patterns = []
286
+ lora_dict = {
287
+ "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
288
+ "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
289
+ }
290
+ for i in range(19):
291
+ for suffix in lora_dict:
292
+ lora_patterns.append({
293
+ "name": f"blocks.{i}.{suffix}",
294
+ "dim": lora_dict[suffix]
295
+ })
296
+ lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
297
+ for i in range(38):
298
+ for suffix in lora_dict:
299
+ lora_patterns.append({
300
+ "name": f"single_blocks.{i}.{suffix}",
301
+ "dim": lora_dict[suffix]
302
+ })
303
+ return lora_patterns
304
+
305
+ def forward(self, base_output, lora_outputs, name):
306
+ return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
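
A hypothetical forward through `FluxLoraPatcher`: module names use `.` in the DiT but `___` as ModuleDict keys, and `forward` translates between the two before dispatching to the named merger. Shapes here match the default pattern `blocks.0.attn.a_to_qkv` (dim 9216).

```python
import torch

patcher = FluxLoraPatcher()
base = torch.randn(1, 16, 9216)
loras = torch.randn(2, 16, 9216)  # outputs of two LoRA branches
out = patcher(base, loras, name="blocks.0.attn.a_to_qkv")
print(out.shape)  # torch.Size([1, 16, 9216])
```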
diffsynth/models/flux_text_encoder_clip.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+
3
+
4
+ class Attention(torch.nn.Module):
5
+
6
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
7
+ super().__init__()
8
+ dim_inner = head_dim * num_heads
9
+ kv_dim = kv_dim if kv_dim is not None else q_dim
10
+ self.num_heads = num_heads
11
+ self.head_dim = head_dim
12
+
13
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
14
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
15
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
16
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
17
+
18
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
19
+ if encoder_hidden_states is None:
20
+ encoder_hidden_states = hidden_states
21
+
22
+ batch_size = encoder_hidden_states.shape[0]
23
+
24
+ q = self.to_q(hidden_states)
25
+ k = self.to_k(encoder_hidden_states)
26
+ v = self.to_v(encoder_hidden_states)
27
+
28
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
29
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
30
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
31
+
32
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
33
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
34
+ hidden_states = hidden_states.to(q.dtype)
35
+
36
+ hidden_states = self.to_out(hidden_states)
37
+
38
+ return hidden_states
39
+
40
+
41
+ class CLIPEncoderLayer(torch.nn.Module):
42
+ def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
43
+ super().__init__()
44
+ self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
45
+ self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
46
+ self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
47
+ self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
48
+ self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
49
+
50
+ self.use_quick_gelu = use_quick_gelu
51
+
52
+ def quickGELU(self, x):
53
+ return x * torch.sigmoid(1.702 * x)
54
+
55
+ def forward(self, hidden_states, attn_mask=None):
56
+ residual = hidden_states
57
+
58
+ hidden_states = self.layer_norm1(hidden_states)
59
+ hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
60
+ hidden_states = residual + hidden_states
61
+
62
+ residual = hidden_states
63
+ hidden_states = self.layer_norm2(hidden_states)
64
+ hidden_states = self.fc1(hidden_states)
65
+ if self.use_quick_gelu:
66
+ hidden_states = self.quickGELU(hidden_states)
67
+ else:
68
+ hidden_states = torch.nn.functional.gelu(hidden_states)
69
+ hidden_states = self.fc2(hidden_states)
70
+ hidden_states = residual + hidden_states
71
+
72
+ return hidden_states
73
+
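
QuickGELU (`x * sigmoid(1.702 * x)`) is the activation used by the original CLIP; it is only an approximation of the exact GELU, as a quick numeric check shows:

```python
import torch

x = torch.linspace(-3, 3, 7)
quick = x * torch.sigmoid(1.702 * x)
exact = torch.nn.functional.gelu(x)
print((quick - exact).abs().max())  # on the order of 1e-2, not identical
```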
74
+
75
+ class FluxTextEncoderClip(torch.nn.Module):
76
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
77
+ super().__init__()
78
+
79
+ # token_embedding
80
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
81
+
82
+ # position_embeds (This is a fixed tensor)
83
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
84
+
85
+ # encoders
86
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
87
+
88
+ # attn_mask
89
+ self.attn_mask = self.attention_mask(max_position_embeddings)
90
+
91
+ # final_layer_norm
92
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
93
+
94
+ def attention_mask(self, length):
95
+ mask = torch.empty(length, length)
96
+ mask.fill_(float("-inf"))
97
+ mask.triu_(1)
98
+ return mask
99
+
100
+ def forward(self, input_ids, clip_skip=2, extra_mask=None):
101
+ embeds = self.token_embedding(input_ids)
102
+ embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
103
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
104
+ if extra_mask is not None:
105
+ attn_mask[:, extra_mask[0]==0] = float("-inf")
106
+ for encoder_id, encoder in enumerate(self.encoders):
107
+ embeds = encoder(embeds, attn_mask=attn_mask)
108
+ if encoder_id + clip_skip == len(self.encoders):
109
+ hidden_states = embeds
110
+ embeds = self.final_layer_norm(embeds)
111
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
112
+ return pooled_embeds, hidden_states
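
A note on the `clip_skip` indexing above: with the default 12 layers and `clip_skip=2`, the captured `hidden_states` is the output of the layer with index 10, i.e. the penultimate layer, which is the conventional meaning of clip_skip=2. The control flow reduces to this sketch:

```python
num_layers, clip_skip = 12, 2
for encoder_id in range(num_layers):
    # embeds = encoder(embeds, ...)
    if encoder_id + clip_skip == num_layers:
        print(f"hidden_states captured after layer index {encoder_id}")  # 10
```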
diffsynth/models/flux_text_encoder_t5.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from transformers import T5EncoderModel, T5Config
3
+
4
+
5
+ class FluxTextEncoderT5(T5EncoderModel):
6
+ def __init__(self):
7
+ config = T5Config(**{
8
+ "architectures": [
9
+ "T5EncoderModel"
10
+ ],
11
+ "classifier_dropout": 0.0,
12
+ "d_ff": 10240,
13
+ "d_kv": 64,
14
+ "d_model": 4096,
15
+ "decoder_start_token_id": 0,
16
+ "dense_act_fn": "gelu_new",
17
+ "dropout_rate": 0.1,
18
+ "dtype": "bfloat16",
19
+ "eos_token_id": 1,
20
+ "feed_forward_proj": "gated-gelu",
21
+ "initializer_factor": 1.0,
22
+ "is_encoder_decoder": True,
23
+ "is_gated_act": True,
24
+ "layer_norm_epsilon": 1e-06,
25
+ "model_type": "t5",
26
+ "num_decoder_layers": 24,
27
+ "num_heads": 64,
28
+ "num_layers": 24,
29
+ "output_past": True,
30
+ "pad_token_id": 0,
31
+ "relative_attention_max_distance": 128,
32
+ "relative_attention_num_buckets": 32,
33
+ "tie_word_embeddings": False,
34
+ "transformers_version": "4.57.1",
35
+ "use_cache": True,
36
+ "vocab_size": 32128
37
+ })
38
+ super().__init__(config)
39
+
40
+ def forward(self, input_ids):
41
+ outputs = super().forward(input_ids=input_ids)
42
+ prompt_emb = outputs.last_hidden_state
43
+ return prompt_emb
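
The configuration above is T5-XXL sized (~4.8B parameters), so instantiating it just to inspect the forward contract is impractical; the same contract on a shrunken T5 encoder (hypothetical tiny config, for illustration only):

```python
import torch
from transformers import T5Config, T5EncoderModel

tiny = T5EncoderModel(T5Config(d_model=64, d_ff=128, d_kv=16, num_heads=4,
                               num_layers=2, vocab_size=32128))
input_ids = torch.randint(0, 32128, (1, 8))
emb = tiny(input_ids=input_ids).last_hidden_state  # what forward() returns
print(emb.shape)  # torch.Size([1, 8, 64]); the real encoder yields (1, 8, 4096)
```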
diffsynth/models/flux_vae.py ADDED
@@ -0,0 +1,451 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
+ class TileWorker:
6
+ def __init__(self):
7
+ pass
8
+
9
+
10
+ def mask(self, height, width, border_width):
11
+ # Create a mask with shape (height, width).
12
+ # The centre area is filled with 1, and the border region ramps linearly through values in (0, 1].
13
+ x = torch.arange(height).repeat(width, 1).T
14
+ y = torch.arange(width).repeat(height, 1)
15
+ mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
16
+ mask = (mask / border_width).clip(0, 1)
17
+ return mask
18
+
19
+
20
+ def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
21
+ # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
22
+ batch_size, channel, _, _ = model_input.shape
23
+ model_input = model_input.to(device=tile_device, dtype=tile_dtype)
24
+ unfold_operator = torch.nn.Unfold(
25
+ kernel_size=(tile_size, tile_size),
26
+ stride=(tile_stride, tile_stride)
27
+ )
28
+ model_input = unfold_operator(model_input)
29
+ model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
30
+
31
+ return model_input
32
+
33
+
34
+ def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
35
+ # Call y=forward_fn(x) for each tile
36
+ tile_num = model_input.shape[-1]
37
+ model_output_stack = []
38
+
39
+ for tile_id in range(0, tile_num, tile_batch_size):
40
+
41
+ # process input
42
+ tile_id_ = min(tile_id + tile_batch_size, tile_num)
43
+ x = model_input[:, :, :, :, tile_id: tile_id_]
44
+ x = x.to(device=inference_device, dtype=inference_dtype)
45
+ x = rearrange(x, "b c h w n -> (n b) c h w")
46
+
47
+ # process output
48
+ y = forward_fn(x)
49
+ y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
50
+ y = y.to(device=tile_device, dtype=tile_dtype)
51
+ model_output_stack.append(y)
52
+
53
+ model_output = torch.concat(model_output_stack, dim=-1)
54
+ return model_output
55
+
56
+
57
+ def io_scale(self, model_output, tile_size):
58
+ # Determine the size modification happened in forward_fn
59
+ # We only consider the same scale on height and width.
60
+ io_scale = model_output.shape[2] / tile_size
61
+ return io_scale
62
+
63
+
64
+ def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
65
+ # The reversed function of tile
66
+ mask = self.mask(tile_size, tile_size, border_width)
67
+ mask = mask.to(device=tile_device, dtype=tile_dtype)
68
+ mask = rearrange(mask, "h w -> 1 1 h w 1")
69
+ model_output = model_output * mask
70
+
71
+ fold_operator = torch.nn.Fold(
72
+ output_size=(height, width),
73
+ kernel_size=(tile_size, tile_size),
74
+ stride=(tile_stride, tile_stride)
75
+ )
76
+ mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
77
+ model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
78
+ model_output = fold_operator(model_output) / fold_operator(mask)
79
+
80
+ return model_output
81
+
82
+
83
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
84
+ # Prepare
85
+ inference_device, inference_dtype = model_input.device, model_input.dtype
86
+ height, width = model_input.shape[2], model_input.shape[3]
87
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
88
+
89
+ # tile
90
+ model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
91
+
92
+ # inference
93
+ model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
94
+
95
+ # resize
96
+ io_scale = self.io_scale(model_output, tile_size)
97
+ height, width = int(height*io_scale), int(width*io_scale)
98
+ tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
99
+ border_width = int(border_width*io_scale)
100
+
101
+ # untile
102
+ model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
103
+
104
+ # Done!
105
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype)
106
+ return model_output
107
+
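
A quick sanity check on the tile/blend/fold round trip: with an identity `forward_fn`, overlapping tiles are re-assembled into (approximately) the original tensor, since `untile` computes a mask-weighted average over the overlapping regions.

```python
import torch

worker = TileWorker()
x = torch.randn(1, 3, 128, 128)
y = worker.tiled_forward(lambda t: t, x, tile_size=64, tile_stride=32)
print((x - y).abs().max())  # near zero (float rounding only)
```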
108
+
109
+ class ConvAttention(torch.nn.Module):
110
+
111
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
112
+ super().__init__()
113
+ dim_inner = head_dim * num_heads
114
+ kv_dim = kv_dim if kv_dim is not None else q_dim
115
+ self.num_heads = num_heads
116
+ self.head_dim = head_dim
117
+
118
+ self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
119
+ self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
120
+ self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
121
+ self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
122
+
123
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
124
+ if encoder_hidden_states is None:
125
+ encoder_hidden_states = hidden_states
126
+
127
+ batch_size = encoder_hidden_states.shape[0]
128
+
129
+ conv_input = rearrange(hidden_states, "B L C -> B C L 1")
130
+ q = self.to_q(conv_input)
131
+ q = rearrange(q[:, :, :, 0], "B C L -> B L C")
132
+ conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
133
+ k = self.to_k(conv_input)
134
+ v = self.to_v(conv_input)
135
+ k = rearrange(k[:, :, :, 0], "B C L -> B L C")
136
+ v = rearrange(v[:, :, :, 0], "B C L -> B L C")
137
+
138
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
139
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
140
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
141
+
142
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
143
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
144
+ hidden_states = hidden_states.to(q.dtype)
145
+
146
+ conv_input = rearrange(hidden_states, "B L C -> B C L 1")
147
+ hidden_states = self.to_out(conv_input)
148
+ hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
149
+
150
+ return hidden_states
151
+
152
+
153
+ class Attention(torch.nn.Module):
154
+
155
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
156
+ super().__init__()
157
+ dim_inner = head_dim * num_heads
158
+ kv_dim = kv_dim if kv_dim is not None else q_dim
159
+ self.num_heads = num_heads
160
+ self.head_dim = head_dim
161
+
162
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
163
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
164
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
165
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
166
+
167
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
168
+ if encoder_hidden_states is None:
169
+ encoder_hidden_states = hidden_states
170
+
171
+ batch_size = encoder_hidden_states.shape[0]
172
+
173
+ q = self.to_q(hidden_states)
174
+ k = self.to_k(encoder_hidden_states)
175
+ v = self.to_v(encoder_hidden_states)
176
+
177
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
178
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
179
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
180
+
181
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
182
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
183
+ hidden_states = hidden_states.to(q.dtype)
184
+
185
+ hidden_states = self.to_out(hidden_states)
186
+
187
+ return hidden_states
188
+
189
+
190
+ class VAEAttentionBlock(torch.nn.Module):
191
+
192
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
193
+ super().__init__()
194
+ inner_dim = num_attention_heads * attention_head_dim
195
+
196
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
197
+
198
+ if use_conv_attention:
199
+ self.transformer_blocks = torch.nn.ModuleList([
200
+ ConvAttention(
201
+ inner_dim,
202
+ num_attention_heads,
203
+ attention_head_dim,
204
+ bias_q=True,
205
+ bias_kv=True,
206
+ bias_out=True
207
+ )
208
+ for d in range(num_layers)
209
+ ])
210
+ else:
211
+ self.transformer_blocks = torch.nn.ModuleList([
212
+ Attention(
213
+ inner_dim,
214
+ num_attention_heads,
215
+ attention_head_dim,
216
+ bias_q=True,
217
+ bias_kv=True,
218
+ bias_out=True
219
+ )
220
+ for d in range(num_layers)
221
+ ])
222
+
223
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
224
+ batch, _, height, width = hidden_states.shape
225
+ residual = hidden_states
226
+
227
+ hidden_states = self.norm(hidden_states)
228
+ inner_dim = hidden_states.shape[1]
229
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
230
+
231
+ for block in self.transformer_blocks:
232
+ hidden_states = block(hidden_states)
233
+
234
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
235
+ hidden_states = hidden_states + residual
236
+
237
+ return hidden_states, time_emb, text_emb, res_stack
238
+
239
+
240
+ class ResnetBlock(torch.nn.Module):
241
+ def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
242
+ super().__init__()
243
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
244
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
245
+ if temb_channels is not None:
246
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
247
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
248
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
249
+ self.nonlinearity = torch.nn.SiLU()
250
+ self.conv_shortcut = None
251
+ if in_channels != out_channels:
252
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
253
+
254
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
255
+ x = hidden_states
256
+ x = self.norm1(x)
257
+ x = self.nonlinearity(x)
258
+ x = self.conv1(x)
259
+ if time_emb is not None:
260
+ emb = self.nonlinearity(time_emb)
261
+ emb = self.time_emb_proj(emb)[:, :, None, None]
262
+ x = x + emb
263
+ x = self.norm2(x)
264
+ x = self.nonlinearity(x)
265
+ x = self.conv2(x)
266
+ if self.conv_shortcut is not None:
267
+ hidden_states = self.conv_shortcut(hidden_states)
268
+ hidden_states = hidden_states + x
269
+ return hidden_states, time_emb, text_emb, res_stack
270
+
271
+
272
+ class UpSampler(torch.nn.Module):
273
+ def __init__(self, channels):
274
+ super().__init__()
275
+ self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
276
+
277
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
278
+ hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
279
+ hidden_states = self.conv(hidden_states)
280
+ return hidden_states, time_emb, text_emb, res_stack
281
+
282
+
283
+ class DownSampler(torch.nn.Module):
284
+ def __init__(self, channels, padding=1, extra_padding=False):
285
+ super().__init__()
286
+ self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
287
+ self.extra_padding = extra_padding
288
+
289
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
290
+ if self.extra_padding:
291
+ hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
292
+ hidden_states = self.conv(hidden_states)
293
+ return hidden_states, time_emb, text_emb, res_stack
294
+
295
+
296
+ class FluxVAEDecoder(torch.nn.Module):
297
+ def __init__(self, use_conv_attention=True):
298
+ super().__init__()
299
+ self.scaling_factor = 0.3611
300
+ self.shift_factor = 0.1159
301
+ self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
302
+
303
+ self.blocks = torch.nn.ModuleList([
304
+ # UNetMidBlock2D
305
+ ResnetBlock(512, 512, eps=1e-6),
306
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
307
+ ResnetBlock(512, 512, eps=1e-6),
308
+ # UpDecoderBlock2D
309
+ ResnetBlock(512, 512, eps=1e-6),
310
+ ResnetBlock(512, 512, eps=1e-6),
311
+ ResnetBlock(512, 512, eps=1e-6),
312
+ UpSampler(512),
313
+ # UpDecoderBlock2D
314
+ ResnetBlock(512, 512, eps=1e-6),
315
+ ResnetBlock(512, 512, eps=1e-6),
316
+ ResnetBlock(512, 512, eps=1e-6),
317
+ UpSampler(512),
318
+ # UpDecoderBlock2D
319
+ ResnetBlock(512, 256, eps=1e-6),
320
+ ResnetBlock(256, 256, eps=1e-6),
321
+ ResnetBlock(256, 256, eps=1e-6),
322
+ UpSampler(256),
323
+ # UpDecoderBlock2D
324
+ ResnetBlock(256, 128, eps=1e-6),
325
+ ResnetBlock(128, 128, eps=1e-6),
326
+ ResnetBlock(128, 128, eps=1e-6),
327
+ ])
328
+
329
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
330
+ self.conv_act = torch.nn.SiLU()
331
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
332
+
333
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
334
+ hidden_states = TileWorker().tiled_forward(
335
+ lambda x: self.forward(x),
336
+ sample,
337
+ tile_size,
338
+ tile_stride,
339
+ tile_device=sample.device,
340
+ tile_dtype=sample.dtype
341
+ )
342
+ return hidden_states
343
+
344
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
345
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
346
+ if tiled:
347
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
348
+
349
+ # 1. pre-process
350
+ hidden_states = sample / self.scaling_factor + self.shift_factor
351
+ hidden_states = self.conv_in(hidden_states)
352
+ time_emb = None
353
+ text_emb = None
354
+ res_stack = None
355
+
356
+ # 2. blocks
357
+ for i, block in enumerate(self.blocks):
358
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
359
+
360
+ # 3. output
361
+ hidden_states = self.conv_norm_out(hidden_states)
362
+ hidden_states = self.conv_act(hidden_states)
363
+ hidden_states = self.conv_out(hidden_states)
364
+
365
+ return hidden_states
366
+
367
+
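
A hypothetical decode sketch (random weights; real use loads the converted VAE checkpoint): FLUX latents carry 16 channels at 1/8 resolution, and the three upsamplers restore the full image size.

```python
import torch

decoder = FluxVAEDecoder()
latents = torch.randn(1, 16, 64, 64)
with torch.no_grad():
    image = decoder(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])
```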
368
+ class FluxVAEEncoder(torch.nn.Module):
369
+ def __init__(self, use_conv_attention=True):
370
+ super().__init__()
371
+ self.scaling_factor = 0.3611
372
+ self.shift_factor = 0.1159
373
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
374
+
375
+ self.blocks = torch.nn.ModuleList([
376
+ # DownEncoderBlock2D
377
+ ResnetBlock(128, 128, eps=1e-6),
378
+ ResnetBlock(128, 128, eps=1e-6),
379
+ DownSampler(128, padding=0, extra_padding=True),
380
+ # DownEncoderBlock2D
381
+ ResnetBlock(128, 256, eps=1e-6),
382
+ ResnetBlock(256, 256, eps=1e-6),
383
+ DownSampler(256, padding=0, extra_padding=True),
384
+ # DownEncoderBlock2D
385
+ ResnetBlock(256, 512, eps=1e-6),
386
+ ResnetBlock(512, 512, eps=1e-6),
387
+ DownSampler(512, padding=0, extra_padding=True),
388
+ # DownEncoderBlock2D
389
+ ResnetBlock(512, 512, eps=1e-6),
390
+ ResnetBlock(512, 512, eps=1e-6),
391
+ # UNetMidBlock2D
392
+ ResnetBlock(512, 512, eps=1e-6),
393
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
394
+ ResnetBlock(512, 512, eps=1e-6),
395
+ ])
396
+
397
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
398
+ self.conv_act = torch.nn.SiLU()
399
+ self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
400
+
401
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
402
+ hidden_states = TileWorker().tiled_forward(
403
+ lambda x: self.forward(x),
404
+ sample,
405
+ tile_size,
406
+ tile_stride,
407
+ tile_device=sample.device,
408
+ tile_dtype=sample.dtype
409
+ )
410
+ return hidden_states
411
+
412
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
413
+ # For the VAE encoder, we do not need to apply the tiler on each layer.
414
+ if tiled:
415
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
416
+
417
+ # 1. pre-process
418
+ hidden_states = self.conv_in(sample)
419
+ time_emb = None
420
+ text_emb = None
421
+ res_stack = None
422
+
423
+ # 2. blocks
424
+ for i, block in enumerate(self.blocks):
425
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
426
+
427
+ # 3. output
428
+ hidden_states = self.conv_norm_out(hidden_states)
429
+ hidden_states = self.conv_act(hidden_states)
430
+ hidden_states = self.conv_out(hidden_states)
431
+ hidden_states = hidden_states[:, :16]
432
+ hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
433
+
434
+ return hidden_states
435
+
436
+ def encode_video(self, sample, batch_size=8):
437
+ B = sample.shape[0]
438
+ hidden_states = []
439
+
440
+ for i in range(0, sample.shape[2], batch_size):
441
+
442
+ j = min(i + batch_size, sample.shape[2])
443
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
444
+
445
+ hidden_states_batch = self(sample_batch)
446
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
447
+
448
+ hidden_states.append(hidden_states_batch)
449
+
450
+ hidden_states = torch.concat(hidden_states, dim=2)
451
+ return hidden_states
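
The encoder mirrors the decoder's shape contract: images enter at (B, 3, H, W) and leave as (B, 16, H/8, W/8), and `encode_video` folds the time axis into the batch in chunks of `batch_size` frames. A sketch with random weights:

```python
import torch

encoder = FluxVAEEncoder()
video = torch.randn(1, 3, 9, 256, 256)  # (B, C, T, H, W)
with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=4)
print(latents.shape)  # torch.Size([1, 16, 9, 32, 32])
```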
diffsynth/models/flux_value_control.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ from .general_modules import TemporalTimesteps
3
+
4
+
5
+ class MultiValueEncoder(torch.nn.Module):
6
+ def __init__(self, encoders=()):
7
+ super().__init__()
8
+ if not isinstance(encoders, (list, tuple)):  # accept the default empty tuple as well
+ encoders = [encoders]
10
+ self.encoders = torch.nn.ModuleList(encoders)
11
+
12
+ def __call__(self, values, dtype):
13
+ emb = []
14
+ for encoder, value in zip(self.encoders, values):
15
+ if value is not None:
16
+ value = value.unsqueeze(0)
17
+ emb.append(encoder(value, dtype))
18
+ emb = torch.concat(emb, dim=0)
19
+ return emb
20
+
21
+
22
+ class SingleValueEncoder(torch.nn.Module):
23
+ def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
24
+ super().__init__()
25
+ self.prefer_len = prefer_len
26
+ self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
27
+ self.prefer_value_embedder = torch.nn.Sequential(
28
+ torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
29
+ )
30
+ self.positional_embedding = torch.nn.Parameter(
31
+ torch.randn(self.prefer_len, dim_out)
32
+ )
33
+
34
+ def forward(self, value, dtype):
35
+ value = value * 1000
36
+ emb = self.prefer_proj(value).to(dtype)
37
+ emb = self.prefer_value_embedder(emb).squeeze(0)
38
+ base_embeddings = emb.expand(self.prefer_len, -1)
39
+ positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
40
+ learned_embeddings = base_embeddings + positional_embedding
41
+ return learned_embeddings
42
+
43
+ @staticmethod
44
+ def state_dict_converter():
45
+ return SingleValueEncoderStateDictConverter()
46
+
47
+
48
+ class SingleValueEncoderStateDictConverter:
49
+ def __init__(self):
50
+ pass
51
+
52
+ def from_diffusers(self, state_dict):
53
+ return state_dict
54
+
55
+ def from_civitai(self, state_dict):
56
+ return state_dict
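
A hypothetical usage of `SingleValueEncoder`: a scalar conditioning value is scaled by 1000 into the timestep range, embedded sinusoidally, projected, and broadcast over `prefer_len` learned positions.

```python
import torch

enc = SingleValueEncoder(dim_in=256, dim_out=4096, prefer_len=32)
value = torch.tensor([0.75])              # already 1-d, as forward() expects
emb = enc(value, dtype=torch.float32)
print(emb.shape)  # torch.Size([32, 4096])
```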
diffsynth/models/general_modules.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch, math
2
+
3
+
4
+ def get_timestep_embedding(
5
+ timesteps: torch.Tensor,
6
+ embedding_dim: int,
7
+ flip_sin_to_cos: bool = False,
8
+ downscale_freq_shift: float = 1,
9
+ scale: float = 1,
10
+ max_period: int = 10000,
11
+ computation_device = None,
12
+ align_dtype_to_timestep = False,
13
+ ):
14
+ assert len(timesteps.shape) == 1, "timesteps should be a 1-d tensor"
15
+
16
+ half_dim = embedding_dim // 2
17
+ exponent = -math.log(max_period) * torch.arange(
18
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
19
+ )
20
+ exponent = exponent / (half_dim - downscale_freq_shift)
21
+
22
+ emb = torch.exp(exponent)
23
+ if align_dtype_to_timestep:
24
+ emb = emb.to(timesteps.dtype)
25
+ emb = timesteps[:, None].float() * emb[None, :]
26
+
27
+ # scale embeddings
28
+ emb = scale * emb
29
+
30
+ # concat sine and cosine embeddings
31
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
32
+
33
+ # flip sine and cosine embeddings
34
+ if flip_sin_to_cos:
35
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
36
+
37
+ # zero pad
38
+ if embedding_dim % 2 == 1:
39
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
40
+ return emb
41
+
42
+
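
A quick shape check for `get_timestep_embedding` (a minimal sketch, assuming the function above is in scope):

```python
import torch

t = torch.tensor([0.0, 500.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([3, 256]); with flip_sin_to_cos the cos half comes first
```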
43
+ class TemporalTimesteps(torch.nn.Module):
44
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
45
+ super().__init__()
46
+ self.num_channels = num_channels
47
+ self.flip_sin_to_cos = flip_sin_to_cos
48
+ self.downscale_freq_shift = downscale_freq_shift
49
+ self.computation_device = computation_device
50
+ self.scale = scale
51
+ self.align_dtype_to_timestep = align_dtype_to_timestep
52
+
53
+ def forward(self, timesteps):
54
+ t_emb = get_timestep_embedding(
55
+ timesteps,
56
+ self.num_channels,
57
+ flip_sin_to_cos=self.flip_sin_to_cos,
58
+ downscale_freq_shift=self.downscale_freq_shift,
59
+ computation_device=self.computation_device,
60
+ scale=self.scale,
61
+ align_dtype_to_timestep=self.align_dtype_to_timestep,
62
+ )
63
+ return t_emb
64
+
65
+
66
+ class DiffusersCompatibleTimestepProj(torch.nn.Module):
67
+ def __init__(self, dim_in, dim_out):
68
+ super().__init__()
69
+ self.linear_1 = torch.nn.Linear(dim_in, dim_out)
70
+ self.act = torch.nn.SiLU()
71
+ self.linear_2 = torch.nn.Linear(dim_out, dim_out)
72
+
73
+ def forward(self, x):
74
+ x = self.linear_1(x)
75
+ x = self.act(x)
76
+ x = self.linear_2(x)
77
+ return x
78
+
79
+
80
+ class TimestepEmbeddings(torch.nn.Module):
81
+ def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
82
+ super().__init__()
83
+ self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
84
+ if diffusers_compatible_format:
85
+ self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
86
+ else:
87
+ self.timestep_embedder = torch.nn.Sequential(
88
+ torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
89
+ )
90
+ self.use_additional_t_cond = use_additional_t_cond
91
+ if use_additional_t_cond:
92
+ self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
93
+
94
+ def forward(self, timestep, dtype, addition_t_cond=None):
95
+ time_emb = self.time_proj(timestep).to(dtype)
96
+ time_emb = self.timestep_embedder(time_emb)
97
+ if addition_t_cond is not None:
98
+ addition_t_emb = self.addition_t_embedding(addition_t_cond)
99
+ addition_t_emb = addition_t_emb.to(dtype=dtype)
100
+ time_emb = time_emb + addition_t_emb
101
+ return time_emb
102
+
103
+
104
+ class RMSNorm(torch.nn.Module):
105
+ def __init__(self, dim, eps, elementwise_affine=True):
106
+ super().__init__()
107
+ self.eps = eps
108
+ if elementwise_affine:
109
+ self.weight = torch.nn.Parameter(torch.ones((dim,)))
110
+ else:
111
+ self.weight = None
112
+
113
+ def forward(self, hidden_states):
114
+ input_dtype = hidden_states.dtype
115
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
116
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
117
+ hidden_states = hidden_states.to(input_dtype)
118
+ if self.weight is not None:
119
+ hidden_states = hidden_states * self.weight
120
+ return hidden_states
121
+
122
+
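
The RMSNorm above normalizes by the root mean square of the last dimension, with no mean subtraction; a minimal numerical check (assuming `RMSNorm` from this file is in scope):

```python
import torch

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 8)
manual = x / torch.sqrt(x.square().mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), manual * norm.weight, atol=1e-5)
```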
123
+ class AdaLayerNorm(torch.nn.Module):
124
+ def __init__(self, dim, single=False, dual=False):
125
+ super().__init__()
126
+ self.single = single
127
+ self.dual = dual
128
+ self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])  # 9 modulation params if dual, 2 if single, else 6
129
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
130
+
131
+ def forward(self, x, emb):
132
+ emb = self.linear(torch.nn.functional.silu(emb))
133
+ if self.single:
134
+ scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
135
+ x = self.norm(x) * (1 + scale) + shift
136
+ return x
137
+ elif self.dual:
138
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
139
+ norm_x = self.norm(x)
140
+ x = norm_x * (1 + scale_msa) + shift_msa
141
+ norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
142
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
143
+ else:
144
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
145
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
146
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
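
The nested-list indexing in `AdaLayerNorm.__init__` resolves to 9 modulation vectors when `dual=True`, 2 when `single=True`, and 6 otherwise, matching the three branches of `forward`. A minimal sketch (assuming `AdaLayerNorm` from this file is in scope):

```python
for single, dual, expected in [(False, False, 6), (True, False, 2), (False, True, 9)]:
    assert [[6, 2][single], 9][dual] == expected       # bools index as 0/1
    layer = AdaLayerNorm(dim=64, single=single, dual=dual)
    assert layer.linear.out_features == 64 * expected
```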
diffsynth/models/longcat_video_dit.py ADDED
@@ -0,0 +1,902 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.amp as amp
7
+
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from .wan_video_dit import flash_attention
12
+ from ..core.device.npu_compatible_device import get_device_type
13
+ from ..core.gradient import gradient_checkpoint_forward
14
+
15
+
16
+ class RMSNorm_FP32(torch.nn.Module):
17
+ def __init__(self, dim: int, eps: float):
18
+ super().__init__()
19
+ self.eps = eps
20
+ self.weight = nn.Parameter(torch.ones(dim))
21
+
22
+ def _norm(self, x):
23
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
24
+
25
+ def forward(self, x):
26
+ output = self._norm(x.float()).type_as(x)
27
+ return output * self.weight
28
+
29
+
30
+ def broadcat(tensors, dim=-1):
31
+ num_tensors = len(tensors)
32
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
33
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
34
+ shape_len = list(shape_lens)[0]
35
+ dim = (dim + shape_len) if dim < 0 else dim
36
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
37
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
38
+ assert all(
39
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
40
+ ), "invalid dimensions for broadcastable concatentation"
41
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
42
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
43
+ expanded_dims.insert(dim, (dim, dims[dim]))
44
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
45
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
46
+ return torch.cat(tensors, dim=dim)
47
+
48
+
49
+ def rotate_half(x):
50
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
51
+ x1, x2 = x.unbind(dim=-1)
52
+ x = torch.stack((-x2, x1), dim=-1)
53
+ return rearrange(x, "... d r -> ... (d r)")
54
+
55
+
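
`rotate_half` rotates each adjacent channel pair by 90 degrees, the core operation of rotary embeddings; a tiny check (assuming `rotate_half` from this file is in scope):

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(rotate_half(x))  # tensor([[-2., 1., -4., 3.]]): each pair (x1, x2) -> (-x2, x1)
```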
56
+ class RotaryPositionalEmbedding(nn.Module):
57
+
58
+ def __init__(self,
59
+ head_dim,
60
+ cp_split_hw=None
61
+ ):
62
+ """Rotary positional embedding for 3D
63
+ Reference : https://blog.eleuther.ai/rotary-embeddings/
64
+ Paper: https://arxiv.org/pdf/2104.09864.pdf
65
+ Args:
66
+ head_dim: Dimension of each attention head
67
+ cp_split_hw: Optional [rows, cols] context-parallel split of the spatial grid
68
+ """
69
+ super().__init__()
70
+ self.head_dim = head_dim
71
+ assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
72
+ self.cp_split_hw = cp_split_hw
73
+ # We assume the longest side of the grid will not be larger than 512, i.e. 512 * 8 = 4096 input pixels
74
+ self.base = 10000
75
+ self.freqs_dict = {}
76
+
77
+ def register_grid_size(self, grid_size):
78
+ if grid_size not in self.freqs_dict:
79
+ self.freqs_dict.update({
80
+ grid_size: self.precompute_freqs_cis_3d(grid_size)
81
+ })
82
+
83
+ def precompute_freqs_cis_3d(self, grid_size):
84
+ num_frames, height, width = grid_size
85
+ dim_t = self.head_dim - 4 * (self.head_dim // 6)
86
+ dim_h = 2 * (self.head_dim // 6)
87
+ dim_w = 2 * (self.head_dim // 6)
88
+ freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
89
+ freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
90
+ freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
91
+ grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
92
+ grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
93
+ grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
94
+ grid_t = torch.from_numpy(grid_t).float()
95
+ grid_h = torch.from_numpy(grid_h).float()
96
+ grid_w = torch.from_numpy(grid_w).float()
97
+ freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
98
+ freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
99
+ freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
100
+ freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
101
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
102
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
103
+ freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
104
+ # (T H W D)
105
+ freqs = rearrange(freqs, "T H W D -> (T H W) D")
106
+ # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
107
+ # with torch.no_grad():
108
+ # freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
109
+ # freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
110
+ # freqs = rearrange(freqs, "T H W D -> (T H W) D")
111
+
112
+ return freqs
113
+
114
+ def forward(self, q, k, grid_size):
115
+ """3D RoPE.
116
+
117
+ Args:
118
+ query: [B, head, seq, head_dim]
119
+ key: [B, head, seq, head_dim]
120
+ Returns:
121
+ query and key with the same shape as input.
122
+ """
123
+
124
+ if grid_size not in self.freqs_dict:
125
+ self.register_grid_size(grid_size)
126
+
127
+ freqs_cis = self.freqs_dict[grid_size].to(q.device)
128
+ q_, k_ = q.float(), k.float()
129
+ freqs_cis = freqs_cis.float().to(q.device)
130
+ cos, sin = freqs_cis.cos(), freqs_cis.sin()
131
+ cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
132
+ q_ = (q_ * cos) + (rotate_half(q_) * sin)
133
+ k_ = (k_ * cos) + (rotate_half(k_) * sin)
134
+
135
+ return q_.type_as(q), k_.type_as(k)
136
+
137
+
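
`precompute_freqs_cis_3d` splits the head dimension across time, height, and width, with the remainder after the two spatial shares going to time. The arithmetic for a typical head size (a minimal sketch of the split used above):

```python
head_dim = 128
dim_h = dim_w = 2 * (head_dim // 6)     # 42 channels each for height and width
dim_t = head_dim - 4 * (head_dim // 6)  # 44 channels for time (absorbs the remainder)
assert dim_t + dim_h + dim_w == head_dim
```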
138
+ class Attention(nn.Module):
139
+ def __init__(
140
+ self,
141
+ dim: int,
142
+ num_heads: int,
143
+ enable_flashattn3: bool = False,
144
+ enable_flashattn2: bool = False,
145
+ enable_xformers: bool = False,
146
+ enable_bsa: bool = False,
147
+ bsa_params: dict = None,
148
+ cp_split_hw: Optional[List[int]] = None
149
+ ) -> None:
150
+ super().__init__()
151
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
152
+ self.dim = dim
153
+ self.num_heads = num_heads
154
+ self.head_dim = dim // num_heads
155
+ self.scale = self.head_dim**-0.5
156
+ self.enable_flashattn3 = enable_flashattn3
157
+ self.enable_flashattn2 = enable_flashattn2
158
+ self.enable_xformers = enable_xformers
159
+ self.enable_bsa = enable_bsa
160
+ self.bsa_params = bsa_params
161
+ self.cp_split_hw = cp_split_hw
162
+
163
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
164
+ self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
165
+ self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
166
+ self.proj = nn.Linear(dim, dim)
167
+
168
+ self.rope_3d = RotaryPositionalEmbedding(
169
+ self.head_dim,
170
+ cp_split_hw=cp_split_hw
171
+ )
172
+
173
+ def _process_attn(self, q, k, v, shape):
174
+ q = rearrange(q, "B H S D -> B S (H D)")
175
+ k = rearrange(k, "B H S D -> B S (H D)")
176
+ v = rearrange(v, "B H S D -> B S (H D)")
177
+ x = flash_attention(q, k, v, num_heads=self.num_heads)
178
+ x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
179
+ return x
180
+
181
+ def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
182
+ """
183
+ """
184
+ B, N, C = x.shape
185
+ qkv = self.qkv(x)
186
+
187
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
188
+ qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
189
+ q, k, v = qkv.unbind(0)
190
+ q, k = self.q_norm(q), self.k_norm(k)
191
+
192
+ if return_kv:
193
+ k_cache, v_cache = k.clone(), v.clone()
194
+
195
+ q, k = self.rope_3d(q, k, shape)
196
+
197
+ # cond mode
198
+ if num_cond_latents is not None and num_cond_latents > 0:
199
+ num_cond_latents_thw = num_cond_latents * (N // shape[0])
200
+ # process the condition tokens
201
+ q_cond = q[:, :, :num_cond_latents_thw].contiguous()
202
+ k_cond = k[:, :, :num_cond_latents_thw].contiguous()
203
+ v_cond = v[:, :, :num_cond_latents_thw].contiguous()
204
+ x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
205
+ # process the noise tokens
206
+ q_noise = q[:, :, num_cond_latents_thw:].contiguous()
207
+ x_noise = self._process_attn(q_noise, k, v, shape)
208
+ # merge x_cond and x_noise
209
+ x = torch.cat([x_cond, x_noise], dim=2).contiguous()
210
+ else:
211
+ x = self._process_attn(q, k, v, shape)
212
+
213
+ x_output_shape = (B, N, C)
214
+ x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
215
+ x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
216
+ x = self.proj(x)
217
+
218
+ if return_kv:
219
+ return x, (k_cache, v_cache)
220
+ else:
221
+ return x
222
+
223
+ def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
224
+ """
225
+ """
226
+ B, N, C = x.shape
227
+ qkv = self.qkv(x)
228
+
229
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
230
+ qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
231
+ q, k, v = qkv.unbind(0)
232
+ q, k = self.q_norm(q), self.k_norm(k)
233
+
234
+ T, H, W = shape
235
+ k_cache, v_cache = kv_cache
236
+ assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
237
+ if k_cache.shape[0] == 1:
238
+ k_cache = k_cache.repeat(B, 1, 1, 1)
239
+ v_cache = v_cache.repeat(B, 1, 1, 1)
240
+
241
+ if num_cond_latents is not None and num_cond_latents > 0:
242
+ k_full = torch.cat([k_cache, k], dim=2).contiguous()
243
+ v_full = torch.cat([v_cache, v], dim=2).contiguous()
244
+ q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()  # placeholder prefix only sets RoPE positions; sliced off below
245
+ q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
246
+ q = q_padding[:, :, -N:].contiguous()
247
+
248
+ x = self._process_attn(q, k_full, v_full, shape)
249
+
250
+ x_output_shape = (B, N, C)
251
+ x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
252
+ x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
253
+ x = self.proj(x)
254
+
255
+ return x
256
+
257
+
258
+ class MultiHeadCrossAttention(nn.Module):
259
+ def __init__(
260
+ self,
261
+ dim,
262
+ num_heads,
263
+ enable_flashattn3=False,
264
+ enable_flashattn2=False,
265
+ enable_xformers=False,
266
+ ):
267
+ super(MultiHeadCrossAttention, self).__init__()
268
+ assert dim % num_heads == 0, "dim must be divisible by num_heads"
269
+
270
+ self.dim = dim
271
+ self.num_heads = num_heads
272
+ self.head_dim = dim // num_heads
273
+
274
+ self.q_linear = nn.Linear(dim, dim)
275
+ self.kv_linear = nn.Linear(dim, dim * 2)
276
+ self.proj = nn.Linear(dim, dim)
277
+
278
+ self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
279
+ self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
280
+
281
+ self.enable_flashattn3 = enable_flashattn3
282
+ self.enable_flashattn2 = enable_flashattn2
283
+ self.enable_xformers = enable_xformers
284
+
285
+ def _process_cross_attn(self, x, cond, kv_seqlen):
286
+ B, N, C = x.shape
287
+ assert C == self.dim and cond.shape[2] == self.dim
288
+
289
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
290
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
291
+ k, v = kv.unbind(2)
292
+
293
+ q, k = self.q_norm(q), self.k_norm(k)
294
+
295
+ q = rearrange(q, "B S H D -> B S (H D)")
296
+ k = rearrange(k, "B S H D -> B S (H D)")
297
+ v = rearrange(v, "B S H D -> B S (H D)")
298
+ x = flash_attention(q, k, v, num_heads=self.num_heads)
299
+
300
+ x = x.view(B, -1, C)
301
+ x = self.proj(x)
302
+ return x
303
+
304
+ def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
305
+ """
306
+ x: [B, N, C]
307
+ cond: [B, M, C]
308
+ """
309
+ if num_cond_latents is None or num_cond_latents == 0:
310
+ return self._process_cross_attn(x, cond, kv_seqlen)
311
+ else:
312
+ B, N, C = x.shape
313
+ if num_cond_latents is not None and num_cond_latents > 0:
314
+ assert shape is not None, "shape must be provided when num_cond_latents > 0"
315
+ num_cond_latents_thw = num_cond_latents * (N // shape[0])
316
+ x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
317
+ output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
318
+ output = torch.cat([
319
+ torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
320
+ output_noise
321
+ ], dim=1).contiguous()
322
+ else:
323
+ raise NotImplementedError
324
+
325
+ return output
326
+
327
+
328
+ class LayerNorm_FP32(nn.LayerNorm):
329
+ def __init__(self, dim, eps, elementwise_affine):
330
+ super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
331
+
332
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
333
+ origin_dtype = inputs.dtype
334
+ out = F.layer_norm(
335
+ inputs.float(),
336
+ self.normalized_shape,
337
+ None if self.weight is None else self.weight.float(),
338
+ None if self.bias is None else self.bias.float(),
339
+ self.eps
340
+ ).to(origin_dtype)
341
+ return out
342
+
343
+
344
+ def modulate_fp32(norm_func, x, shift, scale):
345
+ # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
346
+ # ensure the modulation params be fp32
347
+ assert shift.dtype == torch.float32 and scale.dtype == torch.float32
348
+ dtype = x.dtype
349
+ x = norm_func(x.to(torch.float32))
350
+ x = x * (scale + 1) + shift
351
+ x = x.to(dtype)
352
+ return x
353
+
354
+
355
+ class FinalLayer_FP32(nn.Module):
356
+ """
357
+ The final layer of DiT.
358
+ """
359
+
360
+ def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
361
+ super().__init__()
362
+ self.hidden_size = hidden_size
363
+ self.num_patch = num_patch
364
+ self.out_channels = out_channels
365
+ self.adaln_tembed_dim = adaln_tembed_dim
366
+
367
+ self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
368
+ self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
369
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
370
+
371
+ def forward(self, x, t, latent_shape):
372
+ # timestep shape: [B, T, C]
373
+ assert t.dtype == torch.float32
374
+ B, N, C = x.shape
375
+ T, _, _ = latent_shape
376
+
377
+ with amp.autocast(get_device_type(), dtype=torch.float32):
378
+ shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
379
+ x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
380
+ x = self.linear(x)
381
+ return x
382
+
383
+
384
+ class FeedForwardSwiGLU(nn.Module):
385
+ def __init__(
386
+ self,
387
+ dim: int,
388
+ hidden_dim: int,
389
+ multiple_of: int = 256,
390
+ ffn_dim_multiplier: Optional[float] = None,
391
+ ):
392
+ super().__init__()
393
+ hidden_dim = int(2 * hidden_dim / 3)
394
+ # custom dim factor multiplier
395
+ if ffn_dim_multiplier is not None:
396
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
397
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
398
+
399
+ self.dim = dim
400
+ self.hidden_dim = hidden_dim
401
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
402
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
403
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
404
+
405
+ def forward(self, x):
406
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
407
+
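
`FeedForwardSwiGLU` rescales the requested hidden width by 2/3 (so the three-matrix SwiGLU has roughly the parameter count of a two-matrix MLP) and rounds up to `multiple_of`. For the model defaults (a minimal sketch of the width computation above):

```python
dim, mlp_ratio, multiple_of = 4096, 4, 256
hidden_dim = int(2 * (dim * mlp_ratio) / 3)                         # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim == 11008                                          # rounded up to a multiple of 256
```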
408
+
409
+ class TimestepEmbedder(nn.Module):
410
+ """
411
+ Embeds scalar timesteps into vector representations.
412
+ """
413
+
414
+ def __init__(self, t_embed_dim, frequency_embedding_size=256):
415
+ super().__init__()
416
+ self.t_embed_dim = t_embed_dim
417
+ self.frequency_embedding_size = frequency_embedding_size
418
+ self.mlp = nn.Sequential(
419
+ nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
420
+ nn.SiLU(),
421
+ nn.Linear(t_embed_dim, t_embed_dim, bias=True),
422
+ )
423
+
424
+ @staticmethod
425
+ def timestep_embedding(t, dim, max_period=10000):
426
+ """
427
+ Create sinusoidal timestep embeddings.
428
+ :param t: a 1-D Tensor of N indices, one per batch element.
429
+ These may be fractional.
430
+ :param dim: the dimension of the output.
431
+ :param max_period: controls the minimum frequency of the embeddings.
432
+ :return: an (N, D) Tensor of positional embeddings.
433
+ """
434
+ half = dim // 2
435
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
436
+ freqs = freqs.to(device=t.device)
437
+ args = t[:, None].float() * freqs[None]
438
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
439
+ if dim % 2:
440
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
441
+ return embedding
442
+
443
+ def forward(self, t, dtype):
444
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
445
+ if t_freq.dtype != dtype:
446
+ t_freq = t_freq.to(dtype)
447
+ t_emb = self.mlp(t_freq)
448
+ return t_emb
449
+
450
+
451
+ class CaptionEmbedder(nn.Module):
452
+ """
453
+ Embeds class labels into vector representations.
454
+ """
455
+
456
+ def __init__(self, in_channels, hidden_size):
457
+ super().__init__()
458
+ self.in_channels = in_channels
459
+ self.hidden_size = hidden_size
460
+ self.y_proj = nn.Sequential(
461
+ nn.Linear(in_channels, hidden_size, bias=True),
462
+ nn.GELU(approximate="tanh"),
463
+ nn.Linear(hidden_size, hidden_size, bias=True),
464
+ )
465
+
466
+ def forward(self, caption):
467
+ B, _, N, C = caption.shape
468
+ caption = self.y_proj(caption)
469
+ return caption
470
+
471
+
472
+ class PatchEmbed3D(nn.Module):
473
+ """Video to Patch Embedding.
474
+
475
+ Args:
476
+ patch_size (int): Patch token size. Default: (2,4,4).
477
+ in_chans (int): Number of input video channels. Default: 3.
478
+ embed_dim (int): Number of linear projection output channels. Default: 96.
479
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
480
+ """
481
+
482
+ def __init__(
483
+ self,
484
+ patch_size=(2, 4, 4),
485
+ in_chans=3,
486
+ embed_dim=96,
487
+ norm_layer=None,
488
+ flatten=True,
489
+ ):
490
+ super().__init__()
491
+ self.patch_size = patch_size
492
+ self.flatten = flatten
493
+
494
+ self.in_chans = in_chans
495
+ self.embed_dim = embed_dim
496
+
497
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
498
+ if norm_layer is not None:
499
+ self.norm = norm_layer(embed_dim)
500
+ else:
501
+ self.norm = None
502
+
503
+ def forward(self, x):
504
+ """Forward function."""
505
+ # padding
506
+ _, _, D, H, W = x.size()
507
+ if W % self.patch_size[2] != 0:
508
+ x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
509
+ if H % self.patch_size[1] != 0:
510
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
511
+ if D % self.patch_size[0] != 0:
512
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
513
+
514
+ B, C, T, H, W = x.shape
515
+ x = self.proj(x) # (B C T H W)
516
+ if self.norm is not None:
517
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
518
+ x = x.flatten(2).transpose(1, 2)
519
+ x = self.norm(x)
520
+ x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
521
+ if self.flatten:
522
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
523
+ return x
524
+
525
+
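
With the transformer's default `patch_size=(1, 2, 2)`, `PatchEmbed3D` leaves the temporal axis untouched and halves each spatial axis; the resulting token count (a minimal sketch with assumed latent dimensions):

```python
T, H, W = 16, 60, 104            # assumed latent frames, height, width
patch_size = (1, 2, 2)
tokens = (T // patch_size[0]) * (H // patch_size[1]) * (W // patch_size[2])
assert tokens == 16 * 30 * 52    # 24960 tokens per sample
```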
526
+ class LongCatSingleStreamBlock(nn.Module):
527
+ def __init__(
528
+ self,
529
+ hidden_size: int,
530
+ num_heads: int,
531
+ mlp_ratio: int,
532
+ adaln_tembed_dim: int,
533
+ enable_flashattn3: bool = False,
534
+ enable_flashattn2: bool = False,
535
+ enable_xformers: bool = False,
536
+ enable_bsa: bool = False,
537
+ bsa_params=None,
538
+ cp_split_hw=None
539
+ ):
540
+ super().__init__()
541
+
542
+ self.hidden_size = hidden_size
543
+
544
+ # scale and gate modulation
545
+ self.adaLN_modulation = nn.Sequential(
546
+ nn.SiLU(),
547
+ nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
548
+ )
549
+
550
+ self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
551
+ self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
552
+ self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
553
+
554
+ self.attn = Attention(
555
+ dim=hidden_size,
556
+ num_heads=num_heads,
557
+ enable_flashattn3=enable_flashattn3,
558
+ enable_flashattn2=enable_flashattn2,
559
+ enable_xformers=enable_xformers,
560
+ enable_bsa=enable_bsa,
561
+ bsa_params=bsa_params,
562
+ cp_split_hw=cp_split_hw
563
+ )
564
+ self.cross_attn = MultiHeadCrossAttention(
565
+ dim=hidden_size,
566
+ num_heads=num_heads,
567
+ enable_flashattn3=enable_flashattn3,
568
+ enable_flashattn2=enable_flashattn2,
569
+ enable_xformers=enable_xformers,
570
+ )
571
+ self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
572
+
573
+ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
574
+ """
575
+ x: [B, N, C]
576
+ y: [1, N_valid_tokens, C]
577
+ t: [B, T, C_t]
578
+ y_seqlen: [B]; type of a list
579
+ latent_shape: latent shape of a single item
580
+ """
581
+ x_dtype = x.dtype
582
+
583
+ B, N, C = x.shape
584
+ T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
585
+
586
+ # compute modulation params in fp32
587
+ with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
588
+ shift_msa, scale_msa, gate_msa, \
589
+ shift_mlp, scale_mlp, gate_mlp = \
590
+ self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
591
+
592
+ # self attn with modulation
593
+ x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
594
+
595
+ if kv_cache is not None:
596
+ kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
597
+ attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
598
+ else:
599
+ attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
600
+
601
+ if return_kv:
602
+ x_s, kv_cache = attn_outputs
603
+ else:
604
+ x_s = attn_outputs
605
+
606
+ with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
607
+ x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
608
+ x = x.to(x_dtype)
609
+
610
+ # cross attn
611
+ if not skip_crs_attn:
612
+ if kv_cache is not None:
613
+ num_cond_latents = None
614
+ x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
615
+
616
+ # ffn with modulation
617
+ x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
618
+ x_s = self.ffn(x_m)
619
+ with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
620
+ x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
621
+ x = x.to(x_dtype)
622
+
623
+ if return_kv:
624
+ return x, kv_cache
625
+ else:
626
+ return x
627
+
628
+
629
+ class LongCatVideoTransformer3DModel(torch.nn.Module):
630
+ def __init__(
631
+ self,
632
+ in_channels: int = 16,
633
+ out_channels: int = 16,
634
+ hidden_size: int = 4096,
635
+ depth: int = 48,
636
+ num_heads: int = 32,
637
+ caption_channels: int = 4096,
638
+ mlp_ratio: int = 4,
639
+ adaln_tembed_dim: int = 512,
640
+ frequency_embedding_size: int = 256,
641
+ # default params
642
+ patch_size: Tuple[int] = (1, 2, 2),
643
+ # attention config
644
+ enable_flashattn3: bool = False,
645
+ enable_flashattn2: bool = True,
646
+ enable_xformers: bool = False,
647
+ enable_bsa: bool = False,
648
+ bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
649
+ cp_split_hw: Optional[List[int]] = [1, 1],
650
+ text_tokens_zero_pad: bool = True,
651
+ ) -> None:
652
+ super().__init__()
653
+
654
+ self.patch_size = patch_size
655
+ self.in_channels = in_channels
656
+ self.out_channels = out_channels
657
+ self.cp_split_hw = cp_split_hw
658
+
659
+ self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
660
+ self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
661
+ self.y_embedder = CaptionEmbedder(
662
+ in_channels=caption_channels,
663
+ hidden_size=hidden_size,
664
+ )
665
+
666
+ self.blocks = nn.ModuleList(
667
+ [
668
+ LongCatSingleStreamBlock(
669
+ hidden_size=hidden_size,
670
+ num_heads=num_heads,
671
+ mlp_ratio=mlp_ratio,
672
+ adaln_tembed_dim=adaln_tembed_dim,
673
+ enable_flashattn3=enable_flashattn3,
674
+ enable_flashattn2=enable_flashattn2,
675
+ enable_xformers=enable_xformers,
676
+ enable_bsa=enable_bsa,
677
+ bsa_params=bsa_params,
678
+ cp_split_hw=cp_split_hw
679
+ )
680
+ for i in range(depth)
681
+ ]
682
+ )
683
+
684
+ self.final_layer = FinalLayer_FP32(
685
+ hidden_size,
686
+ np.prod(self.patch_size),
687
+ out_channels,
688
+ adaln_tembed_dim,
689
+ )
690
+
691
+ self.gradient_checkpointing = False
692
+ self.text_tokens_zero_pad = text_tokens_zero_pad
693
+
694
+ self.lora_dict = {}
695
+ self.active_loras = []
696
+
697
+ def enable_loras(self, lora_key_list=[]):
698
+ self.disable_all_loras()
699
+
700
+ module_loras = {} # {module_name: [lora1, lora2, ...]}
701
+ model_device = next(self.parameters()).device
702
+ model_dtype = next(self.parameters()).dtype
703
+
704
+ for lora_key in lora_key_list:
705
+ if lora_key in self.lora_dict:
706
+ for lora in self.lora_dict[lora_key].loras:
707
+ lora.to(model_device, dtype=model_dtype, non_blocking=True)
708
+ module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
709
+ if module_name not in module_loras:
710
+ module_loras[module_name] = []
711
+ module_loras[module_name].append(lora)
712
+ self.active_loras.append(lora_key)
713
+
714
+ for module_name, loras in module_loras.items():
715
+ module = self._get_module_by_name(module_name)
716
+ if not hasattr(module, 'org_forward'):
717
+ module.org_forward = module.forward
718
+ module.forward = self._create_multi_lora_forward(module, loras)
719
+
720
+ def _create_multi_lora_forward(self, module, loras):
721
+ def multi_lora_forward(x, *args, **kwargs):
722
+ weight_dtype = x.dtype
723
+ org_output = module.org_forward(x, *args, **kwargs)
724
+
725
+ total_lora_output = 0
726
+ for lora in loras:
727
+ if lora.use_lora:
728
+ lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
729
+ lx = lora.lora_up(lx)
730
+ lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
731
+ total_lora_output += lora_output
732
+
733
+ return org_output + total_lora_output
734
+
735
+ return multi_lora_forward
736
+
737
+ def _get_module_by_name(self, module_name):
738
+ try:
739
+ module = self
740
+ for part in module_name.split('.'):
741
+ module = getattr(module, part)
742
+ return module
743
+ except AttributeError as e:
744
+ raise ValueError(f"Cannot find module: {module_name}, error: {e}")
745
+
746
+ def disable_all_loras(self):
747
+ for name, module in self.named_modules():
748
+ if hasattr(module, 'org_forward'):
749
+ module.forward = module.org_forward
750
+ delattr(module, 'org_forward')
751
+
752
+ for lora_key, lora_network in self.lora_dict.items():
753
+ for lora in lora_network.loras:
754
+ lora.to("cpu")
755
+
756
+ self.active_loras.clear()
757
+
758
+ def enable_bsa(self,):
759
+ for block in self.blocks:
760
+ block.attn.enable_bsa = True
761
+
762
+ def disable_bsa(self,):
763
+ for block in self.blocks:
764
+ block.attn.enable_bsa = False
765
+
766
+ def forward(
767
+ self,
768
+ hidden_states,
769
+ timestep,
770
+ encoder_hidden_states,
771
+ encoder_attention_mask=None,
772
+ num_cond_latents=0,
773
+ return_kv=False,
774
+ kv_cache_dict={},
775
+ skip_crs_attn=False,
776
+ offload_kv_cache=False,
777
+ use_gradient_checkpointing=False,
778
+ use_gradient_checkpointing_offload=False,
779
+ ):
780
+
781
+ B, _, T, H, W = hidden_states.shape
782
+
783
+ N_t = T // self.patch_size[0]
784
+ N_h = H // self.patch_size[1]
785
+ N_w = W // self.patch_size[2]
786
+
787
+ assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
788
+
789
+ # expand the shape of timestep from [B] to [B, T]
790
+ if len(timestep.shape) == 1:
791
+ timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
792
+ timestep[:, :num_cond_latents] = 0
793
+
794
+ dtype = hidden_states.dtype
795
+ hidden_states = hidden_states.to(dtype)
796
+ timestep = timestep.to(dtype)
797
+ encoder_hidden_states = encoder_hidden_states.to(dtype)
798
+
799
+ hidden_states = self.x_embedder(hidden_states) # [B, N, C]
800
+
801
+ with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
802
+ t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
803
+
804
+ encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
805
+
806
+ if self.text_tokens_zero_pad and encoder_attention_mask is not None:
807
+ encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
808
+ encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)  # all tokens count as valid after zero-padding
809
+
810
+ if encoder_attention_mask is not None:
811
+ encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
812
+ encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
813
+ y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
814
+ else:
815
+ y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
816
+ encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
817
+
818
+ # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
819
+ # hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
820
+ # hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
821
+ # hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
822
+
823
+ # blocks
824
+ kv_cache_dict_ret = {}
825
+ for i, block in enumerate(self.blocks):
826
+ block_outputs = gradient_checkpoint_forward(
827
+ block,
828
+ use_gradient_checkpointing=use_gradient_checkpointing,
829
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
830
+ x=hidden_states,
831
+ y=encoder_hidden_states,
832
+ t=t,
833
+ y_seqlen=y_seqlens,
834
+ latent_shape=(N_t, N_h, N_w),
835
+ num_cond_latents=num_cond_latents,
836
+ return_kv=return_kv,
837
+ kv_cache=kv_cache_dict.get(i, None),
838
+ skip_crs_attn=skip_crs_attn,
839
+ )
840
+
841
+ if return_kv:
842
+ hidden_states, kv_cache = block_outputs
843
+ if offload_kv_cache:
844
+ kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
845
+ else:
846
+ kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
847
+ else:
848
+ hidden_states = block_outputs
849
+
850
+ hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
851
+
852
+ # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
853
+ # hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
854
+
855
+ hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, T, H, W]
856
+
857
+ # cast to float32 for better accuracy
858
+ hidden_states = hidden_states.to(torch.float32)
859
+
860
+ if return_kv:
861
+ return hidden_states, kv_cache_dict_ret
862
+ else:
863
+ return hidden_states
864
+
865
+
866
+ def unpatchify(self, x, N_t, N_h, N_w):
867
+ """
868
+ Args:
869
+ x (torch.Tensor): of shape [B, N, C]
870
+
871
+ Return:
872
+ x (torch.Tensor): of shape [B, C_out, T, H, W]
873
+ """
874
+ T_p, H_p, W_p = self.patch_size
875
+ x = rearrange(
876
+ x,
877
+ "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
878
+ N_t=N_t,
879
+ N_h=N_h,
880
+ N_w=N_w,
881
+ T_p=T_p,
882
+ H_p=H_p,
883
+ W_p=W_p,
884
+ C_out=self.out_channels,
885
+ )
886
+ return x
887
+
888
+ @staticmethod
889
+ def state_dict_converter():
890
+ return LongCatVideoTransformer3DModelDictConverter()
891
+
892
+
893
+ class LongCatVideoTransformer3DModelDictConverter:
894
+ def __init__(self):
895
+ pass
896
+
897
+ def from_diffusers(self, state_dict):
898
+ return state_dict
899
+
900
+ def from_civitai(self, state_dict):
901
+ return state_dict
902
+
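
The `unpatchify` rearrange above can be sanity-checked in isolation (a minimal sketch, assuming `patch_size=(1, 2, 2)` and `out_channels=16`):

```python
import torch
from einops import rearrange

B, N_t, N_h, N_w, C_out = 1, 4, 6, 8, 16
x = torch.randn(B, N_t * N_h * N_w, 1 * 2 * 2 * C_out)
y = rearrange(
    x,
    "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
    N_t=N_t, N_h=N_h, N_w=N_w, T_p=1, H_p=2, W_p=2, C_out=C_out,
)
print(y.shape)  # torch.Size([1, 16, 4, 12, 16])
```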
diffsynth/models/ltx2_audio_vae.py ADDED
@@ -0,0 +1,1872 @@
1
+ from typing import Set, Tuple, Optional, List
2
+ from enum import Enum
3
+ import math
4
+ import einops
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
10
+
11
+
12
+ class AudioProcessor(nn.Module):
13
+ """Converts audio waveforms to log-mel spectrograms with optional resampling."""
14
+
15
+ def __init__(
16
+ self,
17
+ sample_rate: int = 16000,
18
+ mel_bins: int = 64,
19
+ mel_hop_length: int = 160,
20
+ n_fft: int = 1024,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.sample_rate = sample_rate
24
+ self.mel_transform = torchaudio.transforms.MelSpectrogram(
25
+ sample_rate=sample_rate,
26
+ n_fft=n_fft,
27
+ win_length=n_fft,
28
+ hop_length=mel_hop_length,
29
+ f_min=0.0,
30
+ f_max=sample_rate / 2.0,
31
+ n_mels=mel_bins,
32
+ window_fn=torch.hann_window,
33
+ center=True,
34
+ pad_mode="reflect",
35
+ power=1.0,
36
+ mel_scale="slaney",
37
+ norm="slaney",
38
+ )
39
+
40
+ def resample_waveform(
41
+ self,
42
+ waveform: torch.Tensor,
43
+ source_rate: int,
44
+ target_rate: int,
45
+ ) -> torch.Tensor:
46
+ """Resample waveform to target sample rate if needed."""
47
+ if source_rate == target_rate:
48
+ return waveform
49
+ resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
50
+ return resampled.to(device=waveform.device, dtype=waveform.dtype)
51
+
52
+ def waveform_to_mel(
53
+ self,
54
+ waveform: torch.Tensor,
55
+ waveform_sample_rate: int,
56
+ ) -> torch.Tensor:
57
+ """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
58
+ waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
59
+
60
+ mel = self.mel_transform(waveform)
61
+ mel = torch.log(torch.clamp(mel, min=1e-5))
62
+
63
+ mel = mel.to(device=waveform.device, dtype=waveform.dtype)
64
+ return mel.permute(0, 1, 3, 2).contiguous()
65
+
66
+
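
With the defaults above (`sample_rate=16000`, `mel_hop_length=160`, `center=True`), one second of audio yields 101 mel frames; the frame-count arithmetic (a minimal sketch):

```python
sample_rate, hop_length, seconds = 16000, 160, 1
num_frames = seconds * sample_rate // hop_length + 1  # centered STFT adds one frame
assert num_frames == 101
```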
67
+ class AudioPatchifier(Patchifier):
68
+ def __init__(
69
+ self,
70
+ patch_size: int,
71
+ sample_rate: int = 16000,
72
+ hop_length: int = 160,
73
+ audio_latent_downsample_factor: int = 4,
74
+ is_causal: bool = True,
75
+ shift: int = 0,
76
+ ):
77
+ """
78
+ Patchifier tailored for spectrogram/audio latents.
79
+ Args:
80
+ patch_size: Number of mel bins combined into a single patch. This
81
+ controls the resolution along the frequency axis.
82
+ sample_rate: Original waveform sampling rate. Used to map latent
83
+ indices back to seconds so downstream consumers can align audio
84
+ and video cues.
85
+ hop_length: Window hop length used for the spectrogram. Determines
86
+ how many real-time samples separate two consecutive latent frames.
87
+ audio_latent_downsample_factor: Ratio between spectrogram frames and
88
+ latent frames; compensates for additional downsampling inside the
89
+ VAE encoder.
90
+ is_causal: When True, timing is shifted to account for causal
91
+ receptive fields so timestamps do not peek into the future.
92
+ shift: Integer offset applied to the latent indices. Enables
93
+ constructing overlapping windows from the same latent sequence.
94
+ """
95
+ self.hop_length = hop_length
96
+ self.sample_rate = sample_rate
97
+ self.audio_latent_downsample_factor = audio_latent_downsample_factor
98
+ self.is_causal = is_causal
99
+ self.shift = shift
100
+ self._patch_size = (1, patch_size, patch_size)
101
+
102
+ @property
103
+ def patch_size(self) -> Tuple[int, int, int]:
104
+ return self._patch_size
105
+
106
+ def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
107
+ return tgt_shape.frames
108
+
109
+ def _get_audio_latent_time_in_sec(
110
+ self,
111
+ start_latent: int,
112
+ end_latent: int,
113
+ dtype: torch.dtype,
114
+ device: Optional[torch.device] = None,
115
+ ) -> torch.Tensor:
116
+ """
117
+ Converts latent indices into real-time seconds while honoring causal
118
+ offsets and the configured hop length.
119
+ Args:
120
+ start_latent: Inclusive start index inside the latent sequence. This
121
+ sets the first timestamp returned.
122
+ end_latent: Exclusive end index. Determines how many timestamps get
123
+ generated.
124
+ dtype: Floating-point dtype used for the returned tensor, allowing
125
+ callers to control precision.
126
+ device: Target device for the timestamp tensor. When omitted the
127
+ computation occurs on CPU to avoid surprising GPU allocations.
128
+ """
129
+ if device is None:
130
+ device = torch.device("cpu")
131
+
132
+ audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
133
+
134
+ audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
135
+
136
+ if self.is_causal:
137
+ # Frame offset for causal alignment.
138
+ # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
139
+ causal_offset = 1
140
+ audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
141
+
142
+ return audio_mel_frame * self.hop_length / self.sample_rate
143
+
144
+ def _compute_audio_timings(
145
+ self,
146
+ batch_size: int,
147
+ num_steps: int,
148
+ device: Optional[torch.device] = None,
149
+ ) -> torch.Tensor:
150
+ """
151
+ Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
152
+ This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
153
+ Args:
154
+ batch_size: Number of sequences to broadcast the timings over.
155
+ num_steps: Number of latent frames (time steps) to convert into timestamps.
156
+ device: Device on which the resulting tensor should reside.
157
+ """
158
+ resolved_device = device
159
+ if resolved_device is None:
160
+ resolved_device = torch.device("cpu")
161
+
162
+ start_timings = self._get_audio_latent_time_in_sec(
163
+ self.shift,
164
+ num_steps + self.shift,
165
+ torch.float32,
166
+ resolved_device,
167
+ )
168
+ start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
169
+
170
+ end_timings = self._get_audio_latent_time_in_sec(
171
+ self.shift + 1,
172
+ num_steps + self.shift + 1,
173
+ torch.float32,
174
+ resolved_device,
175
+ )
176
+ end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
177
+
178
+ return torch.stack([start_timings, end_timings], dim=-1)
179
+
180
+ def patchify(
181
+ self,
182
+ audio_latents: torch.Tensor,
183
+ ) -> torch.Tensor:
184
+ """
185
+ Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
186
+ to derive timestamps for each latent frame based on the configured hop
187
+ length and downsampling.
188
+ Args:
189
+ audio_latents: Latent tensor to patchify.
190
+ Returns:
191
+ Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
192
+ corresponding timing metadata when needed.
193
+ """
194
+ audio_latents = einops.rearrange(
195
+ audio_latents,
196
+ "b c t f -> b t (c f)",
197
+ )
198
+
199
+ return audio_latents
200
+
201
+ def unpatchify(
202
+ self,
203
+ audio_latents: torch.Tensor,
204
+ output_shape: AudioLatentShape,
205
+ ) -> torch.Tensor:
206
+ """
207
+ Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
208
+ Use `get_patch_grid_bounds` to recompute the timestamps that describe each
209
+ frame's position in real time.
210
+ Args:
211
+ audio_latents: Latent tensor to unpatchify.
212
+ output_shape: Shape of the unpatched output tensor.
213
+ Returns:
214
+ Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
215
+ metadata associated with the restored latents.
216
+ """
217
+ # audio_latents shape: (batch, time, freq * channels)
218
+ audio_latents = einops.rearrange(
219
+ audio_latents,
220
+ "b t (c f) -> b c t f",
221
+ c=output_shape.channels,
222
+ f=output_shape.mel_bins,
223
+ )
224
+
225
+ return audio_latents
226
+
227
+ def unpatchify_audio(
228
+ self,
229
+ audio_latents: torch.Tensor,
230
+ channels: int,
231
+ mel_bins: int
232
+ ) -> torch.Tensor:
233
+ audio_latents = einops.rearrange(
234
+ audio_latents,
235
+ "b t (c f) -> b c t f",
236
+ c=channels,
237
+ f=mel_bins,
238
+ )
239
+ return audio_latents
240
+
241
+ def get_patch_grid_bounds(
242
+ self,
243
+ output_shape: AudioLatentShape | VideoLatentShape,
244
+ device: Optional[torch.device] = None,
245
+ ) -> torch.Tensor:
246
+ """
247
+ Return the temporal bounds `[inclusive start, exclusive end)` for every
248
+ patch emitted by `patchify`. For audio this corresponds to timestamps in
249
+ seconds aligned with the original spectrogram grid.
250
+ The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
251
+ - axis 1 (size 1) represents the temporal dimension
252
+ - axis 3 (size 2) stores the `[start, end)` timestamps per patch
253
+ Args:
254
+ output_shape: Audio grid specification describing the number of time steps.
255
+ device: Target device for the returned tensor.
256
+ """
257
+ if not isinstance(output_shape, AudioLatentShape):
258
+ raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
259
+
260
+ return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
261
+
262
+
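
Tracing `_get_audio_latent_time_in_sec` by hand for the defaults (`hop_length=160`, `sample_rate=16000`, `audio_latent_downsample_factor=4`, `is_causal=True`); a minimal sketch of the same arithmetic:

```python
latent_idx = 10
mel_frame = latent_idx * 4             # undo the VAE's temporal downsampling
mel_frame = max(mel_frame + 1 - 4, 0)  # causal offset, clipped at zero
seconds = mel_frame * 160 / 16000
assert abs(seconds - 0.37) < 1e-9      # 37 mel frames at 10 ms per frame
```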
263
+ class AttentionType(Enum):
264
+ """Enum for specifying the attention mechanism type."""
265
+
266
+ VANILLA = "vanilla"
267
+ LINEAR = "linear"
268
+ NONE = "none"
269
+
270
+
271
+ class AttnBlock(torch.nn.Module):
272
+ def __init__(
273
+ self,
274
+ in_channels: int,
275
+ norm_type: NormType = NormType.GROUP,
276
+ ) -> None:
277
+ super().__init__()
278
+ self.in_channels = in_channels
279
+
280
+ self.norm = build_normalization_layer(in_channels, normtype=norm_type)
281
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
282
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
283
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
284
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
285
+
286
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
287
+ h_ = x
288
+ h_ = self.norm(h_)
289
+ q = self.q(h_)
290
+ k = self.k(h_)
291
+ v = self.v(h_)
292
+
293
+ # compute attention
294
+ b, c, h, w = q.shape
295
+ q = q.reshape(b, c, h * w).contiguous()
296
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
297
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
298
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
299
+ w_ = w_ * (int(c) ** (-0.5))
300
+ w_ = torch.nn.functional.softmax(w_, dim=2)
301
+
302
+ # attend to values
303
+ v = v.reshape(b, c, h * w).contiguous()
304
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
305
+ h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
306
+ h_ = h_.reshape(b, c, h, w).contiguous()
307
+
308
+ h_ = self.proj_out(h_)
309
+
310
+ return x + h_
311
+
312
+
313
+ def make_attn(
314
+ in_channels: int,
315
+ attn_type: AttentionType = AttentionType.VANILLA,
316
+ norm_type: NormType = NormType.GROUP,
317
+ ) -> torch.nn.Module:
318
+ match attn_type:
319
+ case AttentionType.VANILLA:
320
+ return AttnBlock(in_channels, norm_type=norm_type)
321
+ case AttentionType.NONE:
322
+ return torch.nn.Identity()
323
+ case AttentionType.LINEAR:
324
+ raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
325
+ case _:
326
+ raise ValueError(f"Unknown attention type: {attn_type}")
327
+
328
+
329
+ class CausalityAxis(Enum):
330
+ """Enum for specifying the causality axis in causal convolutions."""
331
+
332
+ NONE = None
333
+ WIDTH = "width"
334
+ HEIGHT = "height"
335
+ WIDTH_COMPATIBILITY = "width-compatibility"
336
+
337
+
338
+ class CausalConv2d(torch.nn.Module):
339
+ """
340
+ A causal 2D convolution.
341
+ This layer ensures that the output at time `t` only depends on inputs
342
+ at time `t` and earlier. It achieves this by applying asymmetric padding
343
+ to the time dimension (width) before the convolution.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels: int,
349
+ out_channels: int,
350
+ kernel_size: int | tuple[int, int],
351
+ stride: int = 1,
352
+ dilation: int | tuple[int, int] = 1,
353
+ groups: int = 1,
354
+ bias: bool = True,
355
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
356
+ ) -> None:
357
+ super().__init__()
358
+
359
+ self.causality_axis = causality_axis
360
+
361
+ # Ensure kernel_size and dilation are tuples
362
+ kernel_size = torch.nn.modules.utils._pair(kernel_size)
363
+ dilation = torch.nn.modules.utils._pair(dilation)
364
+
365
+ # Calculate padding dimensions
366
+ pad_h = (kernel_size[0] - 1) * dilation[0]
367
+ pad_w = (kernel_size[1] - 1) * dilation[1]
368
+
369
+ # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
370
+ match self.causality_axis:
371
+ case CausalityAxis.NONE:
372
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
373
+ case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
374
+ self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
375
+ case CausalityAxis.HEIGHT:
376
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
377
+ case _:
378
+ raise ValueError(f"Invalid causality_axis: {causality_axis}")
379
+
380
+ # The internal convolution layer uses no padding, as we handle it manually
381
+ self.conv = torch.nn.Conv2d(
382
+ in_channels,
383
+ out_channels,
384
+ kernel_size,
385
+ stride=stride,
386
+ padding=0,
387
+ dilation=dilation,
388
+ groups=groups,
389
+ bias=bias,
390
+ )
391
+
392
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
393
+ # Apply causal padding before convolution
394
+ x = F.pad(x, self.padding)
395
+ return self.conv(x)
396
+
397
+
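
With `causality_axis=HEIGHT` and a 3x3 kernel, all height padding is applied before the input, so output row `t` only sees input rows `<= t`; a shape check (a minimal sketch, assuming `CausalConv2d` and `CausalityAxis` from this file are in scope):

```python
import torch

conv = CausalConv2d(1, 1, kernel_size=3, causality_axis=CausalityAxis.HEIGHT)
print(conv.padding)   # (1, 1, 2, 0): symmetric on width, fully causal on height
x = torch.randn(1, 1, 8, 8)
print(conv(x).shape)  # torch.Size([1, 1, 8, 8]) -- causal padding preserves size
```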
398
+ def make_conv2d(
399
+ in_channels: int,
400
+ out_channels: int,
401
+ kernel_size: int | tuple[int, int],
402
+ stride: int = 1,
403
+ padding: tuple[int, int, int, int] | None = None,
404
+ dilation: int = 1,
405
+ groups: int = 1,
406
+ bias: bool = True,
407
+ causality_axis: CausalityAxis | None = None,
408
+ ) -> torch.nn.Module:
409
+ """
410
+ Create a 2D convolution layer that can be either causal or non-causal.
411
+ Args:
412
+ in_channels: Number of input channels
413
+ out_channels: Number of output channels
414
+ kernel_size: Size of the convolution kernel
415
+ stride: Convolution stride
416
+ padding: Padding (if None, computed from kernel_size; ignored when causality_axis is set)
417
+ dilation: Dilation rate
418
+ groups: Number of groups for grouped convolution
419
+ bias: Whether to use bias
420
+ causality_axis: Dimension along which to apply causality.
421
+ Returns:
422
+ Either a regular Conv2d or CausalConv2d layer
423
+ """
424
+ if causality_axis is not None:
425
+ # For causal convolution, padding is handled internally by CausalConv2d
426
+ return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
427
+ else:
428
+ # For non-causal convolution, use symmetric padding if not specified
429
+ if padding is None:
430
+ padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
431
+
432
+ return torch.nn.Conv2d(
433
+ in_channels,
434
+ out_channels,
435
+ kernel_size,
436
+ stride,
437
+ padding,
438
+ dilation,
439
+ groups,
440
+ bias,
441
+ )
442
+
443
+
444
+
445
+ LRELU_SLOPE = 0.1
446
+
447
+
448
+ class ResBlock1(torch.nn.Module):
449
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
450
+ super().__init__()
451
+ self.convs1 = torch.nn.ModuleList(
452
+ [
453
+ torch.nn.Conv1d(
454
+ channels,
455
+ channels,
456
+ kernel_size,
457
+ 1,
458
+ dilation=dilation[0],
459
+ padding="same",
460
+ ),
461
+ torch.nn.Conv1d(
462
+ channels,
463
+ channels,
464
+ kernel_size,
465
+ 1,
466
+ dilation=dilation[1],
467
+ padding="same",
468
+ ),
469
+ torch.nn.Conv1d(
470
+ channels,
471
+ channels,
472
+ kernel_size,
473
+ 1,
474
+ dilation=dilation[2],
475
+ padding="same",
476
+ ),
477
+ ]
478
+ )
479
+
480
+ self.convs2 = torch.nn.ModuleList(
481
+ [
482
+ torch.nn.Conv1d(
483
+ channels,
484
+ channels,
485
+ kernel_size,
486
+ 1,
487
+ dilation=1,
488
+ padding="same",
489
+ ),
490
+ torch.nn.Conv1d(
491
+ channels,
492
+ channels,
493
+ kernel_size,
494
+ 1,
495
+ dilation=1,
496
+ padding="same",
497
+ ),
498
+ torch.nn.Conv1d(
499
+ channels,
500
+ channels,
501
+ kernel_size,
502
+ 1,
503
+ dilation=1,
504
+ padding="same",
505
+ ),
506
+ ]
507
+ )
508
+
509
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
510
+ for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
511
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
512
+ xt = conv1(xt)
513
+ xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
514
+ xt = conv2(xt)
515
+ x = xt + x
516
+ return x
517
+
518
+
519
+ class ResBlock2(torch.nn.Module):
520
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
521
+ super().__init__()
522
+ self.convs = torch.nn.ModuleList(
523
+ [
524
+ torch.nn.Conv1d(
525
+ channels,
526
+ channels,
527
+ kernel_size,
528
+ 1,
529
+ dilation=dilation[0],
530
+ padding="same",
531
+ ),
532
+ torch.nn.Conv1d(
533
+ channels,
534
+ channels,
535
+ kernel_size,
536
+ 1,
537
+ dilation=dilation[1],
538
+ padding="same",
539
+ ),
540
+ ]
541
+ )
542
+
543
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
544
+ for conv in self.convs:
545
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
546
+ xt = conv(xt)
547
+ x = xt + x
548
+ return x
549
+
550
+
551
+ class ResnetBlock(torch.nn.Module):
552
+ def __init__(
553
+ self,
554
+ *,
555
+ in_channels: int,
556
+ out_channels: int | None = None,
557
+ conv_shortcut: bool = False,
558
+ dropout: float = 0.0,
559
+ temb_channels: int = 512,
560
+ norm_type: NormType = NormType.GROUP,
561
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
562
+ ) -> None:
563
+ super().__init__()
564
+ self.causality_axis = causality_axis
565
+
566
+ if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
567
+ raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
568
+ self.in_channels = in_channels
569
+ out_channels = in_channels if out_channels is None else out_channels
570
+ self.out_channels = out_channels
571
+ self.use_conv_shortcut = conv_shortcut
572
+
573
+ self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
574
+ self.non_linearity = torch.nn.SiLU()
575
+ self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
576
+ if temb_channels > 0:
577
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
578
+ self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
579
+ self.dropout = torch.nn.Dropout(dropout)
580
+ self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
581
+ if self.in_channels != self.out_channels:
582
+ if self.use_conv_shortcut:
583
+ self.conv_shortcut = make_conv2d(
584
+ in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
585
+ )
586
+ else:
587
+ self.nin_shortcut = make_conv2d(
588
+ in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
589
+ )
590
+
591
+ def forward(
592
+ self,
593
+ x: torch.Tensor,
594
+ temb: torch.Tensor | None = None,
595
+ ) -> torch.Tensor:
596
+ h = x
597
+ h = self.norm1(h)
598
+ h = self.non_linearity(h)
599
+ h = self.conv1(h)
600
+
601
+ if temb is not None:
602
+ h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
603
+
604
+ h = self.norm2(h)
605
+ h = self.non_linearity(h)
606
+ h = self.dropout(h)
607
+ h = self.conv2(h)
608
+
609
+ if self.in_channels != self.out_channels:
610
+ x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
611
+
612
+ return x + h
613
+
614
+
615
+ class Downsample(torch.nn.Module):
616
+ """
617
+ A downsampling layer that can use either a strided convolution
618
+ or average pooling. Supports standard and causal padding for the
619
+ convolutional mode.
620
+ """
621
+
622
+ def __init__(
623
+ self,
624
+ in_channels: int,
625
+ with_conv: bool,
626
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
627
+ ) -> None:
628
+ super().__init__()
629
+ self.with_conv = with_conv
630
+ self.causality_axis = causality_axis
631
+
632
+ if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
633
+ raise ValueError("causality is only supported when `with_conv=True`.")
634
+
635
+ if self.with_conv:
636
+ # Stride-2 downsampling conv. torch.nn.Conv2d has no asymmetric
637
+ # padding, so we pad manually in forward().
638
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
639
+
640
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
641
+ if self.with_conv:
642
+ # Padding tuple is in the order: (left, right, top, bottom).
643
+ match self.causality_axis:
644
+ case CausalityAxis.NONE:
645
+ pad = (0, 1, 0, 1)
646
+ case CausalityAxis.WIDTH:
647
+ pad = (2, 0, 0, 1)
648
+ case CausalityAxis.HEIGHT:
649
+ pad = (0, 1, 2, 0)
650
+ case CausalityAxis.WIDTH_COMPATIBILITY:
651
+ pad = (1, 0, 0, 1)
652
+ case _:
653
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
654
+
655
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
656
+ x = self.conv(x)
657
+ else:
658
+ # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
659
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
660
+
661
+ return x
662
+
663
+
664
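A quick length check of the `CausalityAxis.WIDTH` branch of `Downsample` above, as a standalone sketch mirroring its padding tuple: a causal width of 1 + 2n maps to 1 + n after the stride-2 convolution.

```python
import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=0)
x = torch.randn(1, 1, 8, 9)            # width 9 = 1 + 2*4
y = conv(F.pad(x, (2, 0, 0, 1)))       # (left, right, top, bottom)
assert y.shape[-1] == 5                # 1 + 4
```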
+ def build_downsampling_path( # noqa: PLR0913
665
+ *,
666
+ ch: int,
667
+ ch_mult: Tuple[int, ...],
668
+ num_resolutions: int,
669
+ num_res_blocks: int,
670
+ resolution: int,
671
+ temb_channels: int,
672
+ dropout: float,
673
+ norm_type: NormType,
674
+ causality_axis: CausalityAxis,
675
+ attn_type: AttentionType,
676
+ attn_resolutions: Set[int],
677
+ resamp_with_conv: bool,
678
+ ) -> tuple[torch.nn.ModuleList, int]:
679
+ """Build the downsampling path with residual blocks, attention, and downsampling layers."""
680
+ down_modules = torch.nn.ModuleList()
681
+ curr_res = resolution
682
+ in_ch_mult = (1, *tuple(ch_mult))
683
+ block_in = ch
684
+
685
+ for i_level in range(num_resolutions):
686
+ block = torch.nn.ModuleList()
687
+ attn = torch.nn.ModuleList()
688
+ block_in = ch * in_ch_mult[i_level]
689
+ block_out = ch * ch_mult[i_level]
690
+
691
+ for _ in range(num_res_blocks):
692
+ block.append(
693
+ ResnetBlock(
694
+ in_channels=block_in,
695
+ out_channels=block_out,
696
+ temb_channels=temb_channels,
697
+ dropout=dropout,
698
+ norm_type=norm_type,
699
+ causality_axis=causality_axis,
700
+ )
701
+ )
702
+ block_in = block_out
703
+ if curr_res in attn_resolutions:
704
+ attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
705
+
706
+ down = torch.nn.Module()
707
+ down.block = block
708
+ down.attn = attn
709
+ if i_level != num_resolutions - 1:
710
+ down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
711
+ curr_res = curr_res // 2
712
+ down_modules.append(down)
713
+
714
+ return down_modules, block_in
715
+
716
+
717
+ class Upsample(torch.nn.Module):
718
+ def __init__(
719
+ self,
720
+ in_channels: int,
721
+ with_conv: bool,
722
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
723
+ ) -> None:
724
+ super().__init__()
725
+ self.with_conv = with_conv
726
+ self.causality_axis = causality_axis
727
+ if self.with_conv:
728
+ self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
729
+
730
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
731
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
732
+ if self.with_conv:
733
+ x = self.conv(x)
734
+ # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
735
+ # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
736
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
737
+ # So the output elements rely on the following windows:
738
+ # 0: [-,-,0]
739
+ # 1: [-,0,0]
740
+ # 2: [0,0,1]
741
+ # 3: [0,1,1]
742
+ # 4: [1,1,2]
743
+ # 5: [1,2,2]
744
+ # Notice that the first and second elements in the output rely only on the first element in the input,
745
+ # while all other elements rely on two elements in the input.
746
+ # So we can drop the first element to undo the padding (rather than the last element).
747
+ # This is a no-op for non-causal convolutions.
748
+ match self.causality_axis:
749
+ case CausalityAxis.NONE:
750
+ pass # x remains unchanged
751
+ case CausalityAxis.HEIGHT:
752
+ x = x[:, :, 1:, :]
753
+ case CausalityAxis.WIDTH:
754
+ x = x[:, :, :, 1:]
755
+ case CausalityAxis.WIDTH_COMPATIBILITY:
756
+ pass # x remains unchanged
757
+ case _:
758
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
759
+
760
+ return x
761
+
762
+
763
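The drop-first bookkeeping in `Upsample` can be verified with a tiny standalone sketch: a causal length of 1 + n frames becomes 1 + 2n after nearest 2x interpolation plus dropping the first element.

```python
import torch
import torch.nn.functional as F

x = torch.arange(3.0).view(1, 1, 3, 1)                          # frames [0, 1, 2]
up = F.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")  # [0, 0, 1, 1, 2, 2]
up = up[:, :, 1:, :]                                            # drop first -> [0, 1, 1, 2, 2]
assert up.shape[2] == 2 * x.shape[2] - 1
```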
+ def build_upsampling_path( # noqa: PLR0913
764
+ *,
765
+ ch: int,
766
+ ch_mult: Tuple[int, ...],
767
+ num_resolutions: int,
768
+ num_res_blocks: int,
769
+ resolution: int,
770
+ temb_channels: int,
771
+ dropout: float,
772
+ norm_type: NormType,
773
+ causality_axis: CausalityAxis,
774
+ attn_type: AttentionType,
775
+ attn_resolutions: Set[int],
776
+ resamp_with_conv: bool,
777
+ initial_block_channels: int,
778
+ ) -> tuple[torch.nn.ModuleList, int]:
779
+ """Build the upsampling path with residual blocks, attention, and upsampling layers."""
780
+ up_modules = torch.nn.ModuleList()
781
+ block_in = initial_block_channels
782
+ curr_res = resolution // (2 ** (num_resolutions - 1))
783
+
784
+ for level in reversed(range(num_resolutions)):
785
+ stage = torch.nn.Module()
786
+ stage.block = torch.nn.ModuleList()
787
+ stage.attn = torch.nn.ModuleList()
788
+ block_out = ch * ch_mult[level]
789
+
790
+ for _ in range(num_res_blocks + 1):
791
+ stage.block.append(
792
+ ResnetBlock(
793
+ in_channels=block_in,
794
+ out_channels=block_out,
795
+ temb_channels=temb_channels,
796
+ dropout=dropout,
797
+ norm_type=norm_type,
798
+ causality_axis=causality_axis,
799
+ )
800
+ )
801
+ block_in = block_out
802
+ if curr_res in attn_resolutions:
803
+ stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
804
+
805
+ if level != 0:
806
+ stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
807
+ curr_res *= 2
808
+
809
+ up_modules.insert(0, stage)
810
+
811
+ return up_modules, block_in
812
+
813
+
814
+ class PerChannelStatistics(nn.Module):
815
+ """
816
+ Per-channel statistics for normalizing and denormalizing the latent representation.
817
+ These statistics are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
818
+ """
819
+
820
+ def __init__(self, latent_channels: int = 128) -> None:
821
+ super().__init__()
822
+ self.register_buffer("std-of-means", torch.empty(latent_channels))
823
+ self.register_buffer("mean-of-means", torch.empty(latent_channels))
824
+
825
+ def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
826
+ return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
827
+
828
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
829
+ return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
830
+
831
+
832
+ LATENT_DOWNSAMPLE_FACTOR = 4
833
+
834
+
835
+ def build_mid_block(
836
+ channels: int,
837
+ temb_channels: int,
838
+ dropout: float,
839
+ norm_type: NormType,
840
+ causality_axis: CausalityAxis,
841
+ attn_type: AttentionType,
842
+ add_attention: bool,
843
+ ) -> torch.nn.Module:
844
+ """Build the middle block with two ResNet blocks and optional attention."""
845
+ mid = torch.nn.Module()
846
+ mid.block_1 = ResnetBlock(
847
+ in_channels=channels,
848
+ out_channels=channels,
849
+ temb_channels=temb_channels,
850
+ dropout=dropout,
851
+ norm_type=norm_type,
852
+ causality_axis=causality_axis,
853
+ )
854
+ mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
855
+ mid.block_2 = ResnetBlock(
856
+ in_channels=channels,
857
+ out_channels=channels,
858
+ temb_channels=temb_channels,
859
+ dropout=dropout,
860
+ norm_type=norm_type,
861
+ causality_axis=causality_axis,
862
+ )
863
+ return mid
864
+
865
+
866
+ def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
867
+ """Run features through the middle block."""
868
+ features = mid.block_1(features, temb=None)
869
+ features = mid.attn_1(features)
870
+ return mid.block_2(features, temb=None)
871
+
872
+
873
+ class LTX2AudioEncoder(torch.nn.Module):
874
+ """
875
+ Encoder that compresses audio spectrograms into latent representations.
876
+ The encoder uses a series of downsampling blocks with residual connections,
877
+ attention mechanisms, and configurable causal convolutions.
878
+ """
879
+
880
+ def __init__( # noqa: PLR0913
881
+ self,
882
+ *,
883
+ ch: int = 128,
884
+ ch_mult: Tuple[int, ...] = (1, 2, 4),
885
+ num_res_blocks: int = 2,
886
+ attn_resolutions: Set[int] = set(),
887
+ dropout: float = 0.0,
888
+ resamp_with_conv: bool = True,
889
+ in_channels: int = 2,
890
+ resolution: int = 256,
891
+ z_channels: int = 8,
892
+ double_z: bool = True,
893
+ attn_type: AttentionType = AttentionType.VANILLA,
894
+ mid_block_add_attention: bool = False,
895
+ norm_type: NormType = NormType.PIXEL,
896
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
897
+ sample_rate: int = 16000,
898
+ mel_hop_length: int = 160,
899
+ n_fft: int = 1024,
900
+ is_causal: bool = True,
901
+ mel_bins: int = 64,
902
+ **_ignore_kwargs,
903
+ ) -> None:
904
+ """
905
+ Initialize the Encoder.
906
+ Args:
907
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
908
+ (audio_vae.model.params.ddconfig):
909
+ ch: Base number of feature channels used in the first convolution layer.
910
+ ch_mult: Multiplicative factors for the number of channels at each resolution level.
911
+ num_res_blocks: Number of residual blocks to use at each resolution level.
912
+ attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
913
+ resolution: Input spatial resolution of the spectrogram (height, width).
914
+ z_channels: Number of channels in the latent representation.
915
+ norm_type: Normalization layer type to use within the network (e.g., group, batch).
916
+ causality_axis: Axis along which convolutions should be causal (e.g., time axis).
917
+ sample_rate: Audio sample rate in Hz for the input signals.
918
+ mel_hop_length: Hop length used when computing the mel spectrogram.
919
+ n_fft: FFT size used to compute the spectrogram.
920
+ mel_bins: Number of mel-frequency bins in the input spectrogram.
921
+ in_channels: Number of channels in the input spectrogram tensor.
922
+ double_z: If True, predict both mean and log-variance (doubling latent channels).
923
+ is_causal: If True, use causal convolutions suitable for streaming setups.
924
+ dropout: Dropout probability used in residual and mid blocks.
925
+ attn_type: Type of attention mechanism to use in attention blocks.
926
+ resamp_with_conv: If True, perform resolution changes using strided convolutions.
927
+ mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
928
+ """
929
+ super().__init__()
930
+
931
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
932
+ self.sample_rate = sample_rate
933
+ self.mel_hop_length = mel_hop_length
934
+ self.n_fft = n_fft
935
+ self.is_causal = is_causal
936
+ self.mel_bins = mel_bins
937
+
938
+ self.patchifier = AudioPatchifier(
939
+ patch_size=1,
940
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
941
+ sample_rate=sample_rate,
942
+ hop_length=mel_hop_length,
943
+ is_causal=is_causal,
944
+ )
945
+
946
+ self.ch = ch
947
+ self.temb_ch = 0
948
+ self.num_resolutions = len(ch_mult)
949
+ self.num_res_blocks = num_res_blocks
950
+ self.resolution = resolution
951
+ self.in_channels = in_channels
952
+ self.z_channels = z_channels
953
+ self.double_z = double_z
954
+ self.norm_type = norm_type
955
+ self.causality_axis = causality_axis
956
+ self.attn_type = attn_type
957
+
958
+ # downsampling
959
+ self.conv_in = make_conv2d(
960
+ in_channels,
961
+ self.ch,
962
+ kernel_size=3,
963
+ stride=1,
964
+ causality_axis=self.causality_axis,
965
+ )
966
+
967
+ self.non_linearity = torch.nn.SiLU()
968
+
969
+ self.down, block_in = build_downsampling_path(
970
+ ch=ch,
971
+ ch_mult=ch_mult,
972
+ num_resolutions=self.num_resolutions,
973
+ num_res_blocks=num_res_blocks,
974
+ resolution=resolution,
975
+ temb_channels=self.temb_ch,
976
+ dropout=dropout,
977
+ norm_type=self.norm_type,
978
+ causality_axis=self.causality_axis,
979
+ attn_type=self.attn_type,
980
+ attn_resolutions=attn_resolutions,
981
+ resamp_with_conv=resamp_with_conv,
982
+ )
983
+
984
+ self.mid = build_mid_block(
985
+ channels=block_in,
986
+ temb_channels=self.temb_ch,
987
+ dropout=dropout,
988
+ norm_type=self.norm_type,
989
+ causality_axis=self.causality_axis,
990
+ attn_type=self.attn_type,
991
+ add_attention=mid_block_add_attention,
992
+ )
993
+
994
+ self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
995
+ self.conv_out = make_conv2d(
996
+ block_in,
997
+ 2 * z_channels if double_z else z_channels,
998
+ kernel_size=3,
999
+ stride=1,
1000
+ causality_axis=self.causality_axis,
1001
+ )
1002
+
1003
+ def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
1004
+ """
1005
+ Encode audio spectrogram into latent representations.
1006
+ Args:
1007
+ spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
1008
+ Returns:
1009
+ Encoded latent representation of shape (batch, channels, frames, mel_bins)
1010
+ """
1011
+ h = self.conv_in(spectrogram)
1012
+ h = self._run_downsampling_path(h)
1013
+ h = run_mid_block(self.mid, h)
1014
+ h = self._finalize_output(h)
1015
+
1016
+ return self._normalize_latents(h)
1017
+
1018
+ def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
1019
+ for level in range(self.num_resolutions):
1020
+ stage = self.down[level]
1021
+ for block_idx in range(self.num_res_blocks):
1022
+ h = stage.block[block_idx](h, temb=None)
1023
+ if stage.attn:
1024
+ h = stage.attn[block_idx](h)
1025
+
1026
+ if level != self.num_resolutions - 1:
1027
+ h = stage.downsample(h)
1028
+
1029
+ return h
1030
+
1031
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
1032
+ h = self.norm_out(h)
1033
+ h = self.non_linearity(h)
1034
+ return self.conv_out(h)
1035
+
1036
+ def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
1037
+ """
1038
+ Normalize encoder latents using per-channel statistics.
1039
+ When the encoder is configured with ``double_z=True``, the final
1040
+ convolution produces twice the number of latent channels, typically
1041
+ interpreted as two concatenated tensors along the channel dimension
1042
+ (e.g., mean and variance or other auxiliary parameters).
1043
+ This method intentionally uses only the first half of the channels
1044
+ (the "mean" component) as input to the patchifier and normalization
1045
+ logic. The remaining channels are left unchanged by this method and
1046
+ are expected to be consumed elsewhere in the VAE pipeline.
1047
+ Note that ``torch.chunk(x, 2, dim=1)`` always halves the channel
1048
+ dimension, so this method assumes the encoder runs with ``double_z=True``.
1049
+ """
1050
+ means = torch.chunk(latent_output, 2, dim=1)[0]
1051
+ latent_shape = AudioLatentShape(
1052
+ batch=means.shape[0],
1053
+ channels=means.shape[1],
1054
+ frames=means.shape[2],
1055
+ mel_bins=means.shape[3],
1056
+ )
1057
+ latent_patched = self.patchifier.patchify(means)
1058
+ latent_normalized = self.per_channel_statistics.normalize(latent_patched)
1059
+ return self.patchifier.unpatchify(latent_normalized, latent_shape)
1060
+
1061
+
1062
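A toy illustration of the `double_z` chunking performed in `_normalize_latents` above (shapes here are invented for the example): the final conv emits `2 * z_channels`, and only the first half, the mean component, is patchified and normalized.

```python
import torch

z_channels = 8
latent = torch.randn(2, 2 * z_channels, 10, 16)  # (batch, 2*z, frames, mel_bins)
means = torch.chunk(latent, 2, dim=1)[0]
assert means.shape == (2, z_channels, 10, 16)
```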
+ class LTX2AudioDecoder(torch.nn.Module):
1063
+ """
1064
+ Symmetric decoder that reconstructs audio spectrograms from latent features.
1065
+ The decoder mirrors the encoder structure with configurable channel multipliers,
1066
+ attention resolutions, and causal convolutions.
1067
+ """
1068
+
1069
+ def __init__( # noqa: PLR0913
1070
+ self,
1071
+ *,
1072
+ ch: int = 128,
1073
+ out_ch: int = 2,
1074
+ ch_mult: Tuple[int, ...] = (1, 2, 4),
1075
+ num_res_blocks: int = 2,
1076
+ attn_resolutions: Set[int] = set(),
1077
+ resolution: int = 256,
1078
+ z_channels: int = 8,
1079
+ norm_type: NormType = NormType.PIXEL,
1080
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
1081
+ dropout: float = 0.0,
1082
+ mid_block_add_attention: bool = False,
1083
+ sample_rate: int = 16000,
1084
+ mel_hop_length: int = 160,
1085
+ is_causal: bool = True,
1086
+ mel_bins: int | None = 64,
1087
+ ) -> None:
1088
+ """
1089
+ Initialize the Decoder.
1090
+ Args:
1091
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
1092
+ (audio_vae.model.params.ddconfig):
1093
+ - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
1094
+ - resolution, z_channels
1095
+ - norm_type, causality_axis
1096
+ """
1097
+ super().__init__()
1098
+
1099
+ # Internal behavioural defaults that are not driven by the checkpoint.
1100
+ resamp_with_conv = True
1101
+ attn_type = AttentionType.VANILLA
1102
+
1103
+ # Per-channel statistics for denormalizing latents
1104
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
1105
+ self.sample_rate = sample_rate
1106
+ self.mel_hop_length = mel_hop_length
1107
+ self.is_causal = is_causal
1108
+ self.mel_bins = mel_bins
1109
+ self.patchifier = AudioPatchifier(
1110
+ patch_size=1,
1111
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
1112
+ sample_rate=sample_rate,
1113
+ hop_length=mel_hop_length,
1114
+ is_causal=is_causal,
1115
+ )
1116
+
1117
+ self.ch = ch
1118
+ self.temb_ch = 0
1119
+ self.num_resolutions = len(ch_mult)
1120
+ self.num_res_blocks = num_res_blocks
1121
+ self.resolution = resolution
1122
+ self.out_ch = out_ch
1123
+ self.give_pre_end = False
1124
+ self.tanh_out = False
1125
+ self.norm_type = norm_type
1126
+ self.z_channels = z_channels
1127
+ self.channel_multipliers = ch_mult
1128
+ self.attn_resolutions = attn_resolutions
1129
+ self.causality_axis = causality_axis
1130
+ self.attn_type = attn_type
1131
+
1132
+ base_block_channels = ch * self.channel_multipliers[-1]
1133
+ base_resolution = resolution // (2 ** (self.num_resolutions - 1))
1134
+ self.z_shape = (1, z_channels, base_resolution, base_resolution)
1135
+
1136
+ self.conv_in = make_conv2d(
1137
+ z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
1138
+ )
1139
+ self.non_linearity = torch.nn.SiLU()
1140
+ self.mid = build_mid_block(
1141
+ channels=base_block_channels,
1142
+ temb_channels=self.temb_ch,
1143
+ dropout=dropout,
1144
+ norm_type=self.norm_type,
1145
+ causality_axis=self.causality_axis,
1146
+ attn_type=self.attn_type,
1147
+ add_attention=mid_block_add_attention,
1148
+ )
1149
+ self.up, final_block_channels = build_upsampling_path(
1150
+ ch=ch,
1151
+ ch_mult=ch_mult,
1152
+ num_resolutions=self.num_resolutions,
1153
+ num_res_blocks=num_res_blocks,
1154
+ resolution=resolution,
1155
+ temb_channels=self.temb_ch,
1156
+ dropout=dropout,
1157
+ norm_type=self.norm_type,
1158
+ causality_axis=self.causality_axis,
1159
+ attn_type=self.attn_type,
1160
+ attn_resolutions=attn_resolutions,
1161
+ resamp_with_conv=resamp_with_conv,
1162
+ initial_block_channels=base_block_channels,
1163
+ )
1164
+
1165
+ self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
1166
+ self.conv_out = make_conv2d(
1167
+ final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
1168
+ )
1169
+
1170
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
1171
+ """
1172
+ Decode latent features back to audio spectrograms.
1173
+ Args:
1174
+ sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
1175
+ Returns:
1176
+ Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
1177
+ """
1178
+ sample, target_shape = self._denormalize_latents(sample)
1179
+
1180
+ h = self.conv_in(sample)
1181
+ h = run_mid_block(self.mid, h)
1182
+ h = self._run_upsampling_path(h)
1183
+ h = self._finalize_output(h)
1184
+
1185
+ return self._adjust_output_shape(h, target_shape)
1186
+
1187
+ def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
1188
+ latent_shape = AudioLatentShape(
1189
+ batch=sample.shape[0],
1190
+ channels=sample.shape[1],
1191
+ frames=sample.shape[2],
1192
+ mel_bins=sample.shape[3],
1193
+ )
1194
+
1195
+ sample_patched = self.patchifier.patchify(sample)
1196
+ sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
1197
+ sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
1198
+
1199
+ target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
1200
+ if self.causality_axis != CausalityAxis.NONE:
1201
+ target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
1202
+
1203
+ target_shape = AudioLatentShape(
1204
+ batch=latent_shape.batch,
1205
+ channels=self.out_ch,
1206
+ frames=target_frames,
1207
+ mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
1208
+ )
1209
+
1210
+ return sample, target_shape
1211
+
1212
+ def _adjust_output_shape(
1213
+ self,
1214
+ decoded_output: torch.Tensor,
1215
+ target_shape: AudioLatentShape,
1216
+ ) -> torch.Tensor:
1217
+ """
1218
+ Adjust output shape to match target dimensions for variable-length audio.
1219
+ This function handles the common case where decoded audio spectrograms need to be
1220
+ resized to match a specific target shape.
1221
+ Args:
1222
+ decoded_output: Tensor of shape (batch, channels, time, frequency)
1223
+ target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
1224
+ Returns:
1225
+ Tensor adjusted to match target_shape exactly
1226
+ """
1227
+ # Current output shape: (batch, channels, time, frequency)
1228
+ _, _, current_time, current_freq = decoded_output.shape
1229
+ target_channels = target_shape.channels
1230
+ target_time = target_shape.frames
1231
+ target_freq = target_shape.mel_bins
1232
+
1233
+ # Step 1: Crop first to avoid exceeding target dimensions
1234
+ decoded_output = decoded_output[
1235
+ :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
1236
+ ]
1237
+
1238
+ # Step 2: Calculate padding needed for time and frequency dimensions
1239
+ time_padding_needed = target_time - decoded_output.shape[2]
1240
+ freq_padding_needed = target_freq - decoded_output.shape[3]
1241
+
1242
+ # Step 3: Apply padding if needed
1243
+ if time_padding_needed > 0 or freq_padding_needed > 0:
1244
+ # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
1245
+ # For audio: pad_left/right = frequency, pad_top/bottom = time
1246
+ padding = (
1247
+ 0,
1248
+ max(freq_padding_needed, 0), # frequency padding (left, right)
1249
+ 0,
1250
+ max(time_padding_needed, 0), # time padding (top, bottom)
1251
+ )
1252
+ decoded_output = F.pad(decoded_output, padding)
1253
+
1254
+ # Step 4: Final safety crop to ensure exact target shape
1255
+ decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
1256
+
1257
+ return decoded_output
1258
+
1259
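A shape-only sketch of the crop-then-pad logic in `_adjust_output_shape` with made-up sizes: the frequency axis is cropped, the time axis is zero-padded, and the result lands exactly on the target shape.

```python
import torch
import torch.nn.functional as F

out = torch.randn(1, 2, 37, 70)                          # decoded (B, C, time, freq)
target_t, target_f = 40, 64                              # need more time, fewer freq bins
out = out[:, :, :min(37, target_t), :min(70, target_f)]  # crop -> (1, 2, 37, 64)
out = F.pad(out, (0, 0, 0, target_t - out.shape[2]))     # pad time -> (1, 2, 40, 64)
assert out.shape == (1, 2, 40, 64)
```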
+ def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
1260
+ for level in reversed(range(self.num_resolutions)):
1261
+ stage = self.up[level]
1262
+ for block_idx, block in enumerate(stage.block):
1263
+ h = block(h, temb=None)
1264
+ if stage.attn:
1265
+ h = stage.attn[block_idx](h)
1266
+
1267
+ if level != 0 and hasattr(stage, "upsample"):
1268
+ h = stage.upsample(h)
1269
+
1270
+ return h
1271
+
1272
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
1273
+ if self.give_pre_end:
1274
+ return h
1275
+
1276
+ h = self.norm_out(h)
1277
+ h = self.non_linearity(h)
1278
+ h = self.conv_out(h)
1279
+ return torch.tanh(h) if self.tanh_out else h
1280
+
1281
+
1282
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
1283
+ return int((kernel_size * dilation - dilation) / 2)
1284
+
1285
+
1286
+ # ---------------------------------------------------------------------------
1287
+ # Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
1288
+ # Adopted from https://github.com/NVIDIA/BigVGAN
1289
+ # ---------------------------------------------------------------------------
1290
+
1291
+
1292
+ def _sinc(x: torch.Tensor) -> torch.Tensor:
1293
+ return torch.where(
1294
+ x == 0,
1295
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
1296
+ torch.sin(math.pi * x) / math.pi / x,
1297
+ )
1298
+
1299
+
1300
+ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
1301
+ even = kernel_size % 2 == 0
1302
+ half_size = kernel_size // 2
1303
+ delta_f = 4 * half_width
1304
+ amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
1305
+ if amplitude > 50.0:
1306
+ beta = 0.1102 * (amplitude - 8.7)
1307
+ elif amplitude >= 21.0:
1308
+ beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
1309
+ else:
1310
+ beta = 0.0
1311
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
1312
+ time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
1313
+ if cutoff == 0:
1314
+ filter_ = torch.zeros_like(time)
1315
+ else:
1316
+ filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
1317
+ filter_ /= filter_.sum()
1318
+ return filter_.view(1, 1, kernel_size)
1319
+
1320
+
1321
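A sanity check on the filter construction (assuming `kaiser_sinc_filter1d` above is in scope): for any nonzero cutoff the taps are normalized to unit DC gain, so low-pass filtering preserves the mean level of the signal.

```python
import torch

f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert f.shape == (1, 1, 12)
assert torch.allclose(f.sum(), torch.tensor(1.0), atol=1e-6)
```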
+ class LowPassFilter1d(nn.Module):
1322
+ def __init__(
1323
+ self,
1324
+ cutoff: float = 0.5,
1325
+ half_width: float = 0.6,
1326
+ stride: int = 1,
1327
+ padding: bool = True,
1328
+ padding_mode: str = "replicate",
1329
+ kernel_size: int = 12,
1330
+ ) -> None:
1331
+ super().__init__()
1332
+ if cutoff < 0.0:
1333
+ raise ValueError("Cutoff must be non-negative.")
1334
+ if cutoff > 0.5:
1335
+ raise ValueError("A cutoff above 0.5 does not make sense.")
1336
+ self.kernel_size = kernel_size
1337
+ self.even = kernel_size % 2 == 0
1338
+ self.pad_left = kernel_size // 2 - int(self.even)
1339
+ self.pad_right = kernel_size // 2
1340
+ self.stride = stride
1341
+ self.padding = padding
1342
+ self.padding_mode = padding_mode
1343
+ self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
1344
+
1345
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1346
+ _, n_channels, _ = x.shape
1347
+ if self.padding:
1348
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
1349
+ return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
1350
+
1351
+
1352
+ class UpSample1d(nn.Module):
1353
+ def __init__(
1354
+ self,
1355
+ ratio: int = 2,
1356
+ kernel_size: int | None = None,
1357
+ persistent: bool = True,
1358
+ window_type: str = "kaiser",
1359
+ ) -> None:
1360
+ super().__init__()
1361
+ self.ratio = ratio
1362
+ self.stride = ratio
1363
+
1364
+ if window_type == "hann":
1365
+ # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
1366
+ rolloff = 0.99
1367
+ lowpass_filter_width = 6
1368
+ width = math.ceil(lowpass_filter_width / rolloff)
1369
+ self.kernel_size = 2 * width * ratio + 1
1370
+ self.pad = width
1371
+ self.pad_left = 2 * width * ratio
1372
+ self.pad_right = self.kernel_size - ratio
1373
+ time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
1374
+ time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
1375
+ window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
1376
+ sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
1377
+ else:
1378
+ # Kaiser-windowed sinc filter (BigVGAN default).
1379
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
1380
+ self.pad = self.kernel_size // ratio - 1
1381
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
1382
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
1383
+ sinc_filter = kaiser_sinc_filter1d(
1384
+ cutoff=0.5 / ratio,
1385
+ half_width=0.6 / ratio,
1386
+ kernel_size=self.kernel_size,
1387
+ )
1388
+
1389
+ self.register_buffer("filter", sinc_filter, persistent=persistent)
1390
+
1391
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1392
+ _, n_channels, _ = x.shape
1393
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
1394
+ filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
1395
+ x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
1396
+ return x[..., self.pad_left : -self.pad_right]
1397
+
1398
+
1399
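A length check for `UpSample1d` (assuming the class above is in scope): with the default kaiser filter, the replicate padding and output slicing are sized so that upsampling by `ratio` returns exactly `ratio * T` samples.

```python
import torch

up = UpSample1d(ratio=2)
x = torch.randn(1, 1, 100)
assert up(x).shape == (1, 1, 200)
```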
+ class DownSample1d(nn.Module):
1400
+ def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
1401
+ super().__init__()
1402
+ self.ratio = ratio
1403
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
1404
+ self.lowpass = LowPassFilter1d(
1405
+ cutoff=0.5 / ratio,
1406
+ half_width=0.6 / ratio,
1407
+ stride=ratio,
1408
+ kernel_size=self.kernel_size,
1409
+ )
1410
+
1411
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1412
+ return self.lowpass(x)
1413
+
1414
+
1415
+ class Activation1d(nn.Module):
1416
+ def __init__(
1417
+ self,
1418
+ activation: nn.Module,
1419
+ up_ratio: int = 2,
1420
+ down_ratio: int = 2,
1421
+ up_kernel_size: int = 12,
1422
+ down_kernel_size: int = 12,
1423
+ ) -> None:
1424
+ super().__init__()
1425
+ self.act = activation
1426
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
1427
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
1428
+
1429
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1430
+ x = self.upsample(x)
1431
+ x = self.act(x)
1432
+ return self.downsample(x)
1433
+
1434
+
1435
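`Activation1d` applies the nonlinearity at 2x the sample rate (upsample, activate, downsample) to suppress aliasing from the nonlinearity's harmonics; the round trip preserves the temporal length, as this sketch checks (using `nn.Tanh` as a stand-in activation):

```python
import torch
from torch import nn

act = Activation1d(nn.Tanh())
x = torch.randn(1, 4, 64)
assert act(x).shape == x.shape
```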
+ class Snake(nn.Module):
1436
+ def __init__(
1437
+ self,
1438
+ in_features: int,
1439
+ alpha: float = 1.0,
1440
+ alpha_trainable: bool = True,
1441
+ alpha_logscale: bool = True,
1442
+ ) -> None:
1443
+ super().__init__()
1444
+ self.alpha_logscale = alpha_logscale
1445
+ self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
1446
+ self.alpha.requires_grad = alpha_trainable
1447
+ self.eps = 1e-9
1448
+
1449
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1450
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
1451
+ if self.alpha_logscale:
1452
+ alpha = torch.exp(alpha)
1453
+ return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
1454
+
1455
+
1456
+ class SnakeBeta(nn.Module):
1457
+ def __init__(
1458
+ self,
1459
+ in_features: int,
1460
+ alpha: float = 1.0,
1461
+ alpha_trainable: bool = True,
1462
+ alpha_logscale: bool = True,
1463
+ ) -> None:
1464
+ super().__init__()
1465
+ self.alpha_logscale = alpha_logscale
1466
+ self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
1467
+ self.alpha.requires_grad = alpha_trainable
1468
+ self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
1469
+ self.beta.requires_grad = alpha_trainable
1470
+ self.eps = 1e-9
1471
+
1472
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1473
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
1474
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
1475
+ if self.alpha_logscale:
1476
+ alpha = torch.exp(alpha)
1477
+ beta = torch.exp(beta)
1478
+ return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
1479
+
1480
+
1481
+ class AMPBlock1(nn.Module):
1482
+ def __init__(
1483
+ self,
1484
+ channels: int,
1485
+ kernel_size: int = 3,
1486
+ dilation: tuple[int, int, int] = (1, 3, 5),
1487
+ activation: str = "snake",
1488
+ ) -> None:
1489
+ super().__init__()
1490
+ act_cls = SnakeBeta if activation == "snakebeta" else Snake
1491
+ self.convs1 = nn.ModuleList(
1492
+ [
1493
+ nn.Conv1d(
1494
+ channels,
1495
+ channels,
1496
+ kernel_size,
1497
+ 1,
1498
+ dilation=dilation[0],
1499
+ padding=get_padding(kernel_size, dilation[0]),
1500
+ ),
1501
+ nn.Conv1d(
1502
+ channels,
1503
+ channels,
1504
+ kernel_size,
1505
+ 1,
1506
+ dilation=dilation[1],
1507
+ padding=get_padding(kernel_size, dilation[1]),
1508
+ ),
1509
+ nn.Conv1d(
1510
+ channels,
1511
+ channels,
1512
+ kernel_size,
1513
+ 1,
1514
+ dilation=dilation[2],
1515
+ padding=get_padding(kernel_size, dilation[2]),
1516
+ ),
1517
+ ]
1518
+ )
1519
+
1520
+ self.convs2 = nn.ModuleList(
1521
+ [
1522
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
1523
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
1524
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
1525
+ ]
1526
+ )
1527
+
1528
+ self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
1529
+ self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
1530
+
1531
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1532
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
1533
+ xt = a1(x)
1534
+ xt = c1(xt)
1535
+ xt = a2(xt)
1536
+ xt = c2(xt)
1537
+ x = x + xt
1538
+ return x
1539
+
1540
+
1541
+ class LTX2Vocoder(torch.nn.Module):
1542
+ """
1543
+ LTX2Vocoder model for synthesizing audio from Mel spectrograms.
1544
+ Args:
1545
+ resblock_kernel_sizes: List of kernel sizes for the residual blocks.
1546
+ This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
1547
+ upsample_rates: List of upsampling rates.
1548
+ This value is read from the checkpoint at `config.vocoder.upsample_rates`.
1549
+ upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
1550
+ This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
1551
+ resblock_dilation_sizes: List of dilation sizes for the residual blocks.
1552
+ This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
1553
+ upsample_initial_channel: Initial number of channels for the upsampling layers.
1554
+ This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
1555
+ resblock: Type of residual block to use ("1" for ResBlock1, otherwise AMPBlock1).
1556
+ This value is read from the checkpoint at `config.vocoder.resblock`.
1557
+ output_sampling_rate: Waveform sample rate.
1558
+ This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
1559
+ activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
1560
+ use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
1561
+ apply_final_activation: Whether to apply the final tanh/clamp activation.
1562
+ use_bias_at_final: Whether to use bias in the final conv layer.
1563
+ """
1564
+
1565
+ def __init__( # noqa: PLR0913
1566
+ self,
1567
+ resblock_kernel_sizes: List[int] | None = None,
1568
+ upsample_rates: List[int] | None = None,
1569
+ upsample_kernel_sizes: List[int] | None = None,
1570
+ resblock_dilation_sizes: List[List[int]] | None = None,
1571
+ upsample_initial_channel: int = 1024,
1572
+ resblock: str = "1",
1573
+ output_sampling_rate: int = 24000,
1574
+ activation: str = "snake",
1575
+ use_tanh_at_final: bool = True,
1576
+ apply_final_activation: bool = True,
1577
+ use_bias_at_final: bool = True,
1578
+ ) -> None:
1579
+ super().__init__()
1580
+
1581
+ # Mutable default values are not supported as default arguments.
1582
+ if resblock_kernel_sizes is None:
1583
+ resblock_kernel_sizes = [3, 7, 11]
1584
+ if upsample_rates is None:
1585
+ upsample_rates = [6, 5, 2, 2, 2]
1586
+ if upsample_kernel_sizes is None:
1587
+ upsample_kernel_sizes = [16, 15, 8, 4, 4]
1588
+ if resblock_dilation_sizes is None:
1589
+ resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
1590
+
1591
+ self.output_sampling_rate = output_sampling_rate
1592
+ self.num_kernels = len(resblock_kernel_sizes)
1593
+ self.num_upsamples = len(upsample_rates)
1594
+ self.use_tanh_at_final = use_tanh_at_final
1595
+ self.apply_final_activation = apply_final_activation
1596
+ self.is_amp = resblock == "AMP1"
1597
+
1598
+ # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
1599
+ # bins each), 2 output channels.
1600
+ self.conv_pre = nn.Conv1d(
1601
+ in_channels=128,
1602
+ out_channels=upsample_initial_channel,
1603
+ kernel_size=7,
1604
+ stride=1,
1605
+ padding=3,
1606
+ )
1607
+ resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
1608
+
1609
+ self.ups = nn.ModuleList(
1610
+ nn.ConvTranspose1d(
1611
+ upsample_initial_channel // (2**i),
1612
+ upsample_initial_channel // (2 ** (i + 1)),
1613
+ kernel_size,
1614
+ stride,
1615
+ padding=(kernel_size - stride) // 2,
1616
+ )
1617
+ for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
1618
+ )
1619
+
1620
+ final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
1621
+ self.resblocks = nn.ModuleList()
1622
+
1623
+ for i in range(len(upsample_rates)):
1624
+ ch = upsample_initial_channel // (2 ** (i + 1))
1625
+ for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
1626
+ if self.is_amp:
1627
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
1628
+ else:
1629
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
1630
+
1631
+ if self.is_amp:
1632
+ self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
1633
+ else:
1634
+ self.act_post = nn.LeakyReLU()
1635
+
1636
+ # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
1637
+ self.conv_post = nn.Conv1d(
1638
+ in_channels=final_channels,
1639
+ out_channels=2,
1640
+ kernel_size=7,
1641
+ stride=1,
1642
+ padding=3,
1643
+ bias=use_bias_at_final,
1644
+ )
1645
+
1646
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1647
+ """
1648
+ Forward pass of the vocoder.
1649
+ Args:
1650
+ x: Input Mel spectrogram tensor. Can be either:
1651
+ - 3D: (batch_size, time, mel_bins) for mono
1652
+ - 4D: (batch_size, 2, time, mel_bins) for stereo
1653
+ Returns:
1654
+ Audio waveform tensor of shape (batch_size, out_channels, audio_length)
1655
+ """
1656
+ x = x.transpose(-2, -1)  # (..., time, mel_bins) -> (..., mel_bins, time); works for both 3D mono and 4D stereo inputs
1657
+
1658
+ if x.dim() == 4: # stereo
1659
+ assert x.shape[1] == 2, "Input must have 2 channels for stereo"
1660
+ x = einops.rearrange(x, "b s c t -> b (s c) t")
1661
+
1662
+ x = self.conv_pre(x)
1663
+
1664
+ for i in range(self.num_upsamples):
1665
+ if not self.is_amp:
1666
+ x = F.leaky_relu(x, LRELU_SLOPE)
1667
+ x = self.ups[i](x)
1668
+ start = i * self.num_kernels
1669
+ end = start + self.num_kernels
1670
+
1671
+ # Evaluate all resblocks with the same input tensor so they can run
1672
+ # independently (and thus in parallel on accelerator hardware) before
1673
+ # aggregating their outputs via mean.
1674
+ block_outputs = torch.stack(
1675
+ [self.resblocks[idx](x) for idx in range(start, end)],
1676
+ dim=0,
1677
+ )
1678
+ x = block_outputs.mean(dim=0)
1679
+
1680
+ x = self.act_post(x)
1681
+ x = self.conv_post(x)
1682
+
1683
+ if self.apply_final_activation:
1684
+ x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
1685
+
1686
+ return x
1687
+
1688
+
1689
+ class _STFTFn(nn.Module):
1690
+ """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
1691
+ The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
1692
+ Hann window are stored as buffers and loaded from the checkpoint. Using the exact
1693
+ bfloat16 bases from training ensures the mel values fed to the BWE generator are
1694
+ bit-identical to what it was trained on.
1695
+ """
1696
+
1697
+ def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
1698
+ super().__init__()
1699
+ self.hop_length = hop_length
1700
+ self.win_length = win_length
1701
+ n_freqs = filter_length // 2 + 1
1702
+ self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
1703
+ self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
1704
+
1705
+ def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
1706
+ """Compute magnitude and phase spectrogram from a batch of waveforms.
1707
+ Applies causal (left-only) padding of win_length - hop_length samples so that
1708
+ each output frame depends only on past and present input — no lookahead.
1709
+ Args:
1710
+ y: Waveform tensor of shape (B, T).
1711
+ Returns:
1712
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
1713
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
1714
+ """
1715
+ if y.dim() == 2:
1716
+ y = y.unsqueeze(1) # (B, 1, T)
1717
+ left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
1718
+ y = F.pad(y, (left_pad, 0))
1719
+ spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
1720
+ n_freqs = spec.shape[1] // 2
1721
+ real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
1722
+ magnitude = torch.sqrt(real**2 + imag**2)
1723
+ phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
1724
+ return magnitude, phase
1725
+
1726
+
1727
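For reference, a forward basis of the kind `_STFTFn` loads can be reconstructed as the real/imaginary DFT rows times the analysis window; this sketch is illustrative only, since the checkpoint ships the exact bfloat16 buffers used in training.

```python
import torch

filter_length = 512
n_freqs = filter_length // 2 + 1
fourier = torch.fft.rfft(torch.eye(filter_length))        # (512, 257) complex
basis = torch.cat([fourier.real, fourier.imag], dim=1).T  # (2*n_freqs, 512)
window = torch.hann_window(filter_length, periodic=True)
forward_basis = (basis * window).unsqueeze(1)             # (2*n_freqs, 1, 512)
assert forward_basis.shape == (2 * n_freqs, 1, filter_length)
```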
+ class MelSTFT(nn.Module):
1728
+ """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
1729
+ Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
1730
+ waveform and projecting the linear magnitude spectrum onto the mel filterbank.
1731
+ The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
1732
+ (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
1733
+ """
1734
+
1735
+ def __init__(
1736
+ self,
1737
+ filter_length: int,
1738
+ hop_length: int,
1739
+ win_length: int,
1740
+ n_mel_channels: int,
1741
+ ) -> None:
1742
+ super().__init__()
1743
+ self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
1744
+
1745
+ # Initialized to zeros; load_state_dict overwrites with the checkpoint's
1746
+ # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
1747
+ n_freqs = filter_length // 2 + 1
1748
+ self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
1749
+
1750
+ def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1751
+ """Compute log-mel spectrogram and auxiliary spectral quantities.
1752
+ Args:
1753
+ y: Waveform tensor of shape (B, T).
1754
+ Returns:
1755
+ log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
1756
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
1757
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
1758
+ energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
1759
+ """
1760
+ magnitude, phase = self.stft_fn(y)
1761
+ energy = torch.norm(magnitude, dim=1)
1762
+ mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
1763
+ log_mel = torch.log(torch.clamp(mel, min=1e-5))
1764
+ return log_mel, magnitude, phase, energy
1765
+
1766
+
1767
+ class LTX2VocoderWithBWE(nn.Module):
1768
+ """LTX2Vocoder with bandwidth extension (BWE) upsampling.
1769
+ Chains a mel-to-wav vocoder with a BWE module that upsamples the output
1770
+ to a higher sample rate. The BWE computes a mel spectrogram from the
1771
+ vocoder output, runs it through a second generator to predict a residual,
1772
+ and adds it to a sinc-resampled skip connection.
1773
+ """
1774
+
1775
+ def __init__(
1776
+ self,
1777
+ input_sampling_rate: int = 16000,
1778
+ output_sampling_rate: int = 48000,
1779
+ hop_length: int = 80,
1780
+ ) -> None:
1781
+ super().__init__()
1782
+ self.vocoder = LTX2Vocoder(
1783
+ resblock_kernel_sizes=[3, 7, 11],
1784
+ upsample_rates=[5, 2, 2, 2, 2, 2],
1785
+ upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
1786
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
1787
+ upsample_initial_channel=1536,
1788
+ resblock="AMP1",
1789
+ activation="snakebeta",
1790
+ use_tanh_at_final=False,
1791
+ apply_final_activation=True,
1792
+ use_bias_at_final=False,
1793
+ output_sampling_rate=input_sampling_rate,
1794
+ )
1795
+ self.bwe_generator = LTX2Vocoder(
1796
+ resblock_kernel_sizes=[3, 7, 11],
1797
+ upsample_rates=[6, 5, 2, 2, 2],
1798
+ upsample_kernel_sizes=[12, 11, 4, 4, 4],
1799
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
1800
+ upsample_initial_channel=512,
1801
+ resblock="AMP1",
1802
+ activation="snakebeta",
1803
+ use_tanh_at_final=False,
1804
+ apply_final_activation=False,
1805
+ use_bias_at_final=False,
1806
+ output_sampling_rate=output_sampling_rate,
1807
+ )
1808
+
1809
+ self.mel_stft = MelSTFT(
1810
+ filter_length=512,
1811
+ hop_length=hop_length,
1812
+ win_length=512,
1813
+ n_mel_channels=64,
1814
+ )
1815
+ self.input_sampling_rate = input_sampling_rate
1816
+ self.output_sampling_rate = output_sampling_rate
1817
+ self.hop_length = hop_length
1818
+ # Compute the resampler on CPU so the sinc filter is materialized even when
1819
+ # the model is constructed on meta device (SingleGPUModelBuilder pattern).
1820
+ # The filter is not stored in the checkpoint (persistent=False).
1821
+ with torch.device("cpu"):
1822
+ self.resampler = UpSample1d(
1823
+ ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
1824
+ )
1825
+
1826
+ @property
1827
+ def conv_pre(self) -> nn.Conv1d:
1828
+ return self.vocoder.conv_pre
1829
+
1830
+ @property
1831
+ def conv_post(self) -> nn.Conv1d:
1832
+ return self.vocoder.conv_post
1833
+
1834
+ def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
1835
+ """Compute log-mel spectrogram from waveform using causal STFT bases.
1836
+ Args:
1837
+ audio: Waveform tensor of shape (B, C, T).
1838
+ Returns:
1839
+ mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
1840
+ """
1841
+ batch, n_channels, _ = audio.shape
1842
+ flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
1843
+ mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
1844
+ return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
1845
+
1846
+ def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
1847
+ """Run the full vocoder + BWE forward pass.
1848
+ Args:
1849
+ mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
1850
+ or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward.
1851
+ Returns:
1852
+ Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
1853
+ """
1854
+ x = self.vocoder(mel_spec)
1855
+ _, _, length_low_rate = x.shape
1856
+ output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
1857
+
1858
+ # Pad to multiple of hop_length for exact mel frame count
1859
+ remainder = length_low_rate % self.hop_length
1860
+ if remainder != 0:
1861
+ x = F.pad(x, (0, self.hop_length - remainder))
1862
+
1863
+ # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
1864
+ mel = self._compute_mel(x)
1865
+
1866
+ # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
1867
+ mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins)
1868
+ residual = self.bwe_generator(mel_for_bwe)
1869
+ skip = self.resampler(x)
1870
+ assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
1871
+
1872
+ return torch.clamp(residual + skip, -1, 1)[..., :output_length]
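A hypothetical end-to-end usage sketch (random weights and zeroed buffers, so the audio itself is meaningless; real runs load the LTX2 checkpoint, and the shapes follow the docstrings above):

```python
import torch

bwe = LTX2VocoderWithBWE(input_sampling_rate=16000, output_sampling_rate=48000)
mel = torch.randn(1, 2, 20, 64)  # (batch, stereo, T_frames, mel_bins)
wav = bwe(mel)                   # 20 frames * 160x vocoder * 3x BWE = 9600 samples
assert wav.shape == (1, 2, 9600)
```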
diffsynth/models/ltx2_common.py ADDED
@@ -0,0 +1,388 @@
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Protocol, Tuple
3
+ import torch
4
+ from torch import nn
5
+ from enum import Enum
6
+
7
+
8
+ class VideoPixelShape(NamedTuple):
9
+ """
10
+ Shape of the tensor representing the video pixel array. Assumes BGR channel format.
11
+ """
12
+
13
+ batch: int
14
+ frames: int
15
+ height: int
16
+ width: int
17
+ fps: float
18
+
19
+
20
+ class SpatioTemporalScaleFactors(NamedTuple):
21
+ """
22
+ Describes the spatiotemporal downscaling between decoded video space and
23
+ the corresponding VAE latent grid.
24
+ """
25
+
26
+ time: int
27
+ width: int
28
+ height: int
29
+
30
+ @classmethod
31
+ def default(cls) -> "SpatioTemporalScaleFactors":
32
+ return cls(time=8, width=32, height=32)
33
+
34
+
35
+ VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
+
+
+ class VideoLatentShape(NamedTuple):
+     """
+     Shape of the tensor representing video in VAE latent space.
+     The latent representation is a 5D tensor with dimensions ordered as
+     (batch, channels, frames, height, width). Spatial and temporal dimensions
+     are downscaled relative to pixel space according to the VAE's scale factors.
+     """
+
+     batch: int
+     channels: int
+     frames: int
+     height: int
+     width: int
+
+     def to_torch_shape(self) -> torch.Size:
+         return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
+
+     @staticmethod
+     def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
+         return VideoLatentShape(
+             batch=shape[0],
+             channels=shape[1],
+             frames=shape[2],
+             height=shape[3],
+             width=shape[4],
+         )
+
+     def mask_shape(self) -> "VideoLatentShape":
+         return self._replace(channels=1)
+
+     @staticmethod
+     def from_pixel_shape(
+         shape: VideoPixelShape,
+         latent_channels: int = 128,
+         scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
+     ) -> "VideoLatentShape":
+         # Use named fields rather than positional indexing: the field order of
+         # SpatioTemporalScaleFactors is (time, width, height), so indexing [1]/[2]
+         # would silently swap the width and height factors whenever they differ.
+         frames = (shape.frames - 1) // scale_factors.time + 1
+         height = shape.height // scale_factors.height
+         width = shape.width // scale_factors.width
+
+         return VideoLatentShape(
+             batch=shape.batch,
+             channels=latent_channels,
+             frames=frames,
+             height=height,
+             width=width,
+         )
+
+     def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
+         return self._replace(
+             channels=3,
+             frames=(self.frames - 1) * scale_factors.time + 1,
+             height=self.height * scale_factors.height,
+             width=self.width * scale_factors.width,
+         )
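A minimal shape check of `from_pixel_shape` and `upscale` with the default (8, 32, 32) factors (illustrative numbers, not taken from this diff):

```python
pixel = VideoPixelShape(batch=1, frames=121, height=704, width=1216, fps=30.0)

latent = VideoLatentShape.from_pixel_shape(pixel)
# frames: (121 - 1) // 8 + 1 = 16; height: 704 // 32 = 22; width: 1216 // 32 = 38
assert (latent.frames, latent.height, latent.width) == (16, 22, 38)

# upscale inverts the temporal formula: (16 - 1) * 8 + 1 = 121
up = latent.upscale()
assert (up.frames, up.height, up.width) == (121, 704, 1216)
```

The `(frames - 1) // time + 1` formula reflects the causal VAE convention noted later in `get_pixel_coords`: the first pixel frame maps to one latent frame, and every subsequent group of `time` frames adds one more.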
+
+
+ class AudioLatentShape(NamedTuple):
+     """
+     Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
+     mel_bins is the number of frequency bins from the mel-spectrogram encoding.
+     """
+
+     batch: int
+     channels: int
+     frames: int
+     mel_bins: int
+
+     def to_torch_shape(self) -> torch.Size:
+         return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
+
+     def mask_shape(self) -> "AudioLatentShape":
+         return self._replace(channels=1, mel_bins=1)
+
+     @staticmethod
+     def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
+         return AudioLatentShape(
+             batch=shape[0],
+             channels=shape[1],
+             frames=shape[2],
+             mel_bins=shape[3],
+         )
+
+     @staticmethod
+     def from_duration(
+         batch: int,
+         duration: float,
+         channels: int = 8,
+         mel_bins: int = 16,
+         sample_rate: int = 16000,
+         hop_length: int = 160,
+         audio_latent_downsample_factor: int = 4,
+     ) -> "AudioLatentShape":
+         latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
+
+         return AudioLatentShape(
+             batch=batch,
+             channels=channels,
+             frames=round(duration * latents_per_second),
+             mel_bins=mel_bins,
+         )
+
+     @staticmethod
+     def from_video_pixel_shape(
+         shape: VideoPixelShape,
+         channels: int = 8,
+         mel_bins: int = 16,
+         sample_rate: int = 16000,
+         hop_length: int = 160,
+         audio_latent_downsample_factor: int = 4,
+     ) -> "AudioLatentShape":
+         return AudioLatentShape.from_duration(
+             batch=shape.batch,
+             duration=float(shape.frames) / float(shape.fps),
+             channels=channels,
+             mel_bins=mel_bins,
+             sample_rate=sample_rate,
+             hop_length=hop_length,
+             audio_latent_downsample_factor=audio_latent_downsample_factor,
+         )
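With the defaults, the audio latent rate works out to `16000 / 160 / 4 = 25` latent frames per second, which makes it easy to match audio latents to a video clip's duration. A worked example (illustrative inputs):

```python
# 25 latent frames per second: sample_rate / hop_length / downsample = 16000 / 160 / 4
audio = AudioLatentShape.from_duration(batch=1, duration=5.0)
assert audio.frames == 125  # round(5.0 * 25)

# A 121-frame clip at 30 fps lasts ~4.033 s -> round(4.0333... * 25) = 101 latent frames
video = VideoPixelShape(batch=1, frames=121, height=704, width=1216, fps=30.0)
assert AudioLatentShape.from_video_pixel_shape(video).frames == 101
```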
+
+
+ @dataclass(frozen=True)
+ class LatentState:
+     """
+     State of latents during the diffusion denoising process.
+
+     Attributes:
+         latent: The current noisy latent tensor being denoised.
+         denoise_mask: Mask encoding the denoising strength for each token
+             (1 = full denoising, 0 = no denoising).
+         positions: Positional indices for each latent element, used for positional embeddings.
+         clean_latent: Initial state of the latent before denoising; may include conditioning latents.
+     """
+
+     latent: torch.Tensor
+     denoise_mask: torch.Tensor
+     positions: torch.Tensor
+     clean_latent: torch.Tensor
+
+     def clone(self) -> "LatentState":
+         return LatentState(
+             latent=self.latent.clone(),
+             denoise_mask=self.denoise_mask.clone(),
+             positions=self.positions.clone(),
+             clean_latent=self.clean_latent.clone(),
+         )
+
+
+ class NormType(Enum):
+     """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
+
+     GROUP = "group"
+     PIXEL = "pixel"
+
+
+ class PixelNorm(nn.Module):
+     """
+     Per-pixel (per-location) RMS normalization layer.
+
+     For each element along the chosen dimension, this layer normalizes the tensor
+     by the root-mean-square of its values across that dimension:
+         y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
+     """
+
+     def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
+         """
+         Args:
+             dim: Dimension along which to compute the RMS (typically channels).
+             eps: Small constant added for numerical stability.
+         """
+         super().__init__()
+         self.dim = dim
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Apply RMS normalization along the configured dimension.
+         """
+         # Compute mean of squared values along `dim`, keep dimensions for broadcasting.
+         mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
+         # Normalize by the root-mean-square (RMS).
+         rms = torch.sqrt(mean_sq + self.eps)
+         return x / rms
+
+
+ def build_normalization_layer(
+     in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
+ ) -> nn.Module:
+     """
+     Create a normalization layer based on the normalization type.
+
+     Args:
+         in_channels: Number of input channels.
+         num_groups: Number of groups for group normalization.
+         normtype: Type of normalization: "group" or "pixel".
+
+     Returns:
+         A normalization layer.
+     """
+     if normtype == NormType.GROUP:
+         return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+     if normtype == NormType.PIXEL:
+         return PixelNorm(dim=1, eps=1e-6)
+     raise ValueError(f"Invalid normalization type: {normtype}")
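A minimal check of the two branches (illustrative shapes): `GroupNorm` normalizes within channel groups and applies a learned affine transform, while `PixelNorm` rescales each location to unit RMS across channels with no learned parameters:

```python
import torch

x = torch.randn(2, 64, 8, 16, 16)  # (B, C, F, H, W), illustrative

gn = build_normalization_layer(64)                           # GroupNorm(32, 64)
pn = build_normalization_layer(64, normtype=NormType.PIXEL)  # PixelNorm over dim=1

y = pn(x)
rms = y.pow(2).mean(dim=1).sqrt()  # per-location RMS over channels
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
assert gn(x).shape == y.shape == x.shape
```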
+
+
+ def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
+     """
+     Root-mean-square (RMS) normalize `x` over its last dimension.
+
+     Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
+     shape and forwards `weight` and `eps`.
+     """
+     return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
+
+
+ @dataclass(frozen=True)
+ class Modality:
+     """
+     Input data for a single modality (video or audio) in the transformer.
+
+     Bundles the latent tokens, timestep embeddings, positional information,
+     and text conditioning context for processing by the diffusion transformer.
+
+     Attributes:
+         latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
+             the batch size, *T* is the total number of tokens (noisy +
+             conditioning), and *D* is the input dimension.
+         sigma: Current sigma value, shape ``(B,)``, used for cross-attention
+             timestep calculation.
+         timesteps: Per-token timestep embeddings, shape ``(B, T)``.
+         positions: Positional coordinates, shape ``(B, 3, T)`` for video
+             (time, height, width) or ``(B, 1, T)`` for audio.
+         context: Text conditioning embeddings from the prompt encoder.
+         enabled: Whether this modality is active in the current forward pass.
+         context_mask: Optional mask for the text context tokens.
+         attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
+             Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
+             attention. ``None`` means unrestricted (full) attention between
+             all tokens. Built incrementally by conditioning items; see
+             :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
+     """
+
+     latent: torch.Tensor  # (B, T, D): batch size, number of tokens, input dimension
+     sigma: torch.Tensor  # (B,): current sigma value, used for cross-attention timestep calculation
+     timesteps: torch.Tensor  # (B, T): per-token timesteps
+     positions: torch.Tensor  # (B, 3, T) for video: positional axes, then tokens
+     context: torch.Tensor
+     enabled: bool = True
+     context_mask: torch.Tensor | None = None
+     attention_mask: torch.Tensor | None = None
+
+
+ def to_denoised(
+     sample: torch.Tensor,
+     velocity: torch.Tensor,
+     sigma: float | torch.Tensor,
+     calc_dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+     """
+     Convert a sample and its predicted denoising velocity to the denoised sample:
+     ``denoised = sample - velocity * sigma``.
+
+     Returns:
+         Denoised sample, cast back to the dtype of `sample`.
+     """
+     if isinstance(sigma, torch.Tensor):
+         sigma = sigma.to(calc_dtype)
+     return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
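This is the standard rectified-flow identity: assuming the noisy sample lies on the linear path `x_sigma = (1 - sigma) * x0 + sigma * noise` and the model predicts the velocity `v = noise - x0`, then `x_sigma - sigma * v` recovers `x0` exactly. A numeric check with synthetic tensors (not model outputs):

```python
import torch

x0 = torch.randn(1, 128, 16, 22, 38)  # clean latent, illustrative shape
noise = torch.randn_like(x0)
sigma = 0.7

x_sigma = (1 - sigma) * x0 + sigma * noise  # sample on the linear path
velocity = noise - x0                       # flow-matching velocity target

assert torch.allclose(to_denoised(x_sigma, velocity, sigma), x0, atol=1e-5)
```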
+
+
+ class Patchifier(Protocol):
+     """
+     Protocol for patchifiers that convert latent tensors into patches and assemble them back.
+     """
+
+     def patchify(
+         self,
+         latents: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Convert latent tensors into flattened patch tokens.
+
+         Args:
+             latents: Latent tensor to patchify.
+
+         Returns:
+             Flattened patch tokens tensor.
+         """
+         ...
+
+     def unpatchify(
+         self,
+         latents: torch.Tensor,
+         output_shape: AudioLatentShape | VideoLatentShape,
+     ) -> torch.Tensor:
+         """
+         Convert flattened patch tokens back into a dense latent tensor.
+
+         Args:
+             latents: Patch tokens to be rearranged back into the latent grid produced by `patchify`.
+             output_shape: Shape of the output tensor, either an AudioLatentShape or a
+                 VideoLatentShape.
+
+         Returns:
+             Dense latent tensor restored from the flattened representation.
+         """
+         ...
+
+     @property
+     def patch_size(self) -> Tuple[int, int, int]:
+         """
+         The patch size as a tuple of (temporal, height, width) dimensions.
+         """
+         ...
+
+     def get_patch_grid_bounds(
+         self,
+         output_shape: AudioLatentShape | VideoLatentShape,
+         device: torch.device | None = None,
+     ) -> torch.Tensor:
+         """
+         Compute metadata describing where each latent patch resides within the
+         grid specified by `output_shape`.
+
+         Args:
+             output_shape: Target grid layout for the patches.
+             device: Target device for the returned tensor.
+
+         Returns:
+             Tensor containing patch coordinate metadata such as spatial or temporal intervals.
+         """
+         ...
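For intuition, a minimal sketch of a conforming implementation (not part of this diff): a toy video patchifier with patch size (1, 2, 2) that folds each patch into the token dimension via reshape/permute, and inverts the operation in `unpatchify`:

```python
import torch

class SimpleVideoPatchifier:
    """Toy Patchifier: (B, C, F, H, W) <-> (B, T, D) with
    T = (F/pt)*(H/ph)*(W/pw) tokens and D = C*pt*ph*pw features."""

    def __init__(self, patch_size: tuple[int, int, int] = (1, 2, 2)) -> None:
        self._patch_size = patch_size

    @property
    def patch_size(self) -> tuple[int, int, int]:
        return self._patch_size

    def patchify(self, latents: torch.Tensor) -> torch.Tensor:
        pt, ph, pw = self._patch_size
        b, c, f, h, w = latents.shape
        # Split each axis into (grid, patch) pairs, then flatten the grid into tokens.
        x = latents.reshape(b, c, f // pt, pt, h // ph, ph, w // pw, pw)
        x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)  # (B, F', H', W', C, pt, ph, pw)
        return x.reshape(b, (f // pt) * (h // ph) * (w // pw), c * pt * ph * pw)

    def unpatchify(self, latents: torch.Tensor, output_shape: VideoLatentShape) -> torch.Tensor:
        pt, ph, pw = self._patch_size
        b, c, f, h, w = output_shape.to_torch_shape()
        # Undo the flattening, then interleave the (grid, patch) pairs back together.
        x = latents.reshape(b, f // pt, h // ph, w // pw, c, pt, ph, pw)
        x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)  # (B, C, F', pt, H', ph, W', pw)
        return x.reshape(b, c, f, h, w)
```

Round-tripping holds whenever the latent grid divides evenly by the patch size: `unpatchify(patchify(x), VideoLatentShape.from_torch_shape(x.shape))` returns `x`.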
+
+
+ def get_pixel_coords(
+     latent_coords: torch.Tensor,
+     scale_factors: SpatioTemporalScaleFactors,
+     causal_fix: bool = False,
+ ) -> torch.Tensor:
+     """
+     Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
+     each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
+     Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
+
+     Args:
+         latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
+         scale_factors: SpatioTemporalScaleFactors with integer scale factors applied per axis.
+         causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
+             that treat frame zero differently still yield non-negative timestamps.
+     """
+     # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
+     # The axis dimension is ordered (frame/time, height, width), while the fields of
+     # SpatioTemporalScaleFactors are ordered (time, width, height), so the factors are
+     # listed explicitly rather than passing the NamedTuple directly.
+     broadcast_shape = [1] * latent_coords.ndim
+     broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
+     scale_tensor = torch.tensor(
+         [scale_factors.time, scale_factors.height, scale_factors.width],
+         device=latent_coords.device,
+     ).view(*broadcast_shape)
+
+     # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
+     pixel_coords = latent_coords * scale_tensor
+
+     if causal_fix:
+         # The VAE temporal stride for the very first frame is 1 instead of `scale_factors.time`.
+         # Shift and clamp to keep the first-frame timestamps causal and non-negative.
+         pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors.time).clamp(min=0)
+
+     return pixel_coords
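A worked example of the causal fix with the default (8, 32, 32) factors (illustrative coordinates): the first latent frame covers pixel frames [0, 8) after naive scaling, but a causal VAE emits only one pixel frame for it, so the temporal bounds shift by `1 - 8 = -7` and clamp at zero:

```python
import torch

# One patch with bounds [0, 1) on each latent axis: shape (batch=1, axis=3, patches=1, bounds=2)
latent_coords = torch.tensor([[[[0, 1]], [[0, 1]], [[0, 1]]]])

coords = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, causal_fix=True)
assert coords[0, 0, 0].tolist() == [0, 1]   # time: [0, 8) shifted by -7, clamped -> [0, 1)
assert coords[0, 1, 0].tolist() == [0, 32]  # height: scaled by 32
assert coords[0, 2, 0].tolist() == [0, 32]  # width: scaled by 32
```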