svjack committed
Commit 99b022c · verified · Parent: 369f429

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +4 -0
  3. .ipynb_checkpoints/README-checkpoint.md +154 -0
  4. 20250204-104122_1234.mp4 +3 -0
  5. 20250204-111149_1234.mp4 +3 -0
  6. 20250204-114357_1234.mp4 +3 -0
  7. README.md +154 -0
  8. Star_im_lora_dir/Star_single_im_lora-000030.safetensors +3 -0
  9. Star_im_lora_dir/Star_single_im_lora-000031.safetensors +3 -0
  10. Star_im_lora_dir/Star_single_im_lora-000032.safetensors +3 -0
  11. Star_im_lora_dir/Star_single_im_lora-000033.safetensors +3 -0
  12. Star_im_lora_dir/Star_single_im_lora-000034.safetensors +3 -0
  13. Star_im_lora_dir/Star_single_im_lora-000035.safetensors +3 -0
  14. Star_im_lora_dir/Star_single_im_lora-000036.safetensors +3 -0
  15. Star_im_lora_dir/Star_single_im_lora-000037.safetensors +3 -0
  16. Star_im_lora_dir/Star_single_im_lora-000038.safetensors +3 -0
  17. Star_im_lora_dir/Star_single_im_lora-000039.safetensors +3 -0
  18. Star_im_lora_dir/Star_single_im_lora-000040.safetensors +3 -0
  19. Star_im_lora_dir/Star_single_im_lora-000041.safetensors +3 -0
  20. Star_im_lora_dir/Star_single_im_lora-000042.safetensors +3 -0
  21. Star_im_lora_dir/Star_single_im_lora-000043.safetensors +3 -0
  22. Star_im_lora_dir/Star_single_im_lora-000044.safetensors +3 -0
  23. Star_im_lora_dir/Star_single_im_lora-000045.safetensors +3 -0
  24. Star_im_lora_dir/Star_single_im_lora-000046.safetensors +3 -0
  25. Star_im_lora_dir/Star_single_im_lora-000047.safetensors +3 -0
  26. Star_im_lora_dir/Star_single_im_lora-000048.safetensors +3 -0
  27. Star_im_lora_dir/Star_single_im_lora-000049.safetensors +3 -0
  28. Star_im_lora_dir/Star_single_im_lora.safetensors +3 -0
  29. cache_latents.py +245 -0
  30. cache_text_encoder_outputs.py +135 -0
  31. convert_lora.py +129 -0
  32. dataset/__init__.py +0 -0
  33. dataset/config_utils.py +359 -0
  34. dataset/dataset_config.md +293 -0
  35. dataset/image_video_dataset.py +1255 -0
  36. hunyuan_model/__init__.py +0 -0
  37. hunyuan_model/activation_layers.py +23 -0
  38. hunyuan_model/attention.py +230 -0
  39. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  40. hunyuan_model/embed_layers.py +132 -0
  41. hunyuan_model/helpers.py +40 -0
  42. hunyuan_model/mlp_layers.py +118 -0
  43. hunyuan_model/models.py +997 -0
  44. hunyuan_model/modulate_layers.py +76 -0
  45. hunyuan_model/norm_layers.py +79 -0
  46. hunyuan_model/pipeline_hunyuan_video.py +1100 -0
  47. hunyuan_model/posemb_layers.py +310 -0
  48. hunyuan_model/text_encoder.py +438 -0
  49. hunyuan_model/token_refiner.py +236 -0
  50. hunyuan_model/vae.py +442 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 20250131-122504_1234.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ 20250131-125418_1234.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ 20250131-130555_1234.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ 20250204-104122_1234.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ 20250204-111149_1234.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ 20250204-114357_1234.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,154 @@
1
+ # Prince Star (Kim Hyesung) HunyuanVideo LoRA
2
+
3
+ This repository contains the necessary setup and scripts to generate videos using the HunyuanVideo model with a LoRA (Low-Rank Adaptation) fine-tuned for Kim Hyesung. Below are the instructions to install dependencies, download models, and run the demo.
4
+
5
+ ---
6
+
7
+ ## Installation
8
+
9
+ ### Step 1: Install System Dependencies
10
+ Run the following command to install required system packages:
11
+ ```bash
12
+ sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
13
+ ```
14
+
15
+ ### Step 2: Clone the Repository
16
+ Clone the repository and navigate to the project directory:
17
+ ```bash
18
+ git clone https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora
19
+ cd Prince_Star_HunyuanVideo_lora
20
+ ```
21
+
22
+ ### Step 3: Install Python Dependencies
23
+ Install the required Python packages:
24
+ ```bash
25
+ conda create -n py310 python=3.10
26
+ conda activate py310
27
+ pip install ipykernel
28
+ python -m ipykernel install --user --name py310 --display-name "py310"
29
+
30
+ pip install -r requirements.txt
31
+ pip install ascii-magic matplotlib tensorboard huggingface_hub
32
+ pip install moviepy==1.0.3
33
+ pip install sageattention==1.0.6
34
+
35
+ pip install torch==2.5.0 torchvision
36
+ ```
37
+
38
+ ---
39
+
40
+ ## Download Models
41
+
42
+ ### Step 1: Download HunyuanVideo Model
43
+ Download the HunyuanVideo model and place it in the `ckpts` directory:
44
+ ```bash
45
+ huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
46
+ ```
47
+
48
+ ### Step 2: Download LLaVA Model
49
+ Download the LLaVA model and preprocess it:
50
+ ```bash
51
+ cd ckpts
52
+ huggingface-cli download xtuner/llava-llama-3-8b-v1_1-transformers --local-dir ./llava-llama-3-8b-v1_1-transformers
53
+ wget https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py
54
+ python preprocess_text_encoder_tokenizer_utils.py --input_dir llava-llama-3-8b-v1_1-transformers --output_dir text_encoder
55
+ ```
56
+
57
+ ### Step 3: Download CLIP Model
58
+ Download the CLIP model for the text encoder:
59
+ ```bash
60
+ huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./text_encoder_2
61
+ ```
62
+
63
+ ---
64
+
65
+ ## Demo
66
+
67
+ ### Generate Video 1: Kim Hyesung Sun
68
+ Run the following command to generate a video of Prince Kim Hyesung:
69
+ ```bash
70
+ python hv_generate_video.py \
71
+ --fp8 \
72
+ --video_size 544 960 \
73
+ --video_length 60 \
74
+ --infer_steps 30 \
75
+ --prompt "fantastic artwork of Kim Hyesung. warm sunset in a rural village. the interior of a futuristic spaceship in the background." \
76
+ --save_path . \
77
+ --output_type both \
78
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
79
+ --attn_mode sdpa \
80
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
81
+ --vae_chunk_size 32 \
82
+ --vae_spatial_tile_sample_min_size 128 \
83
+ --text_encoder1 ckpts/text_encoder \
84
+ --text_encoder2 ckpts/text_encoder_2 \
85
+ --seed 1234 \
86
+ --lora_multiplier 1.0 \
87
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
88
+ ```
89
+
90
+
91
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-104122_1234.mp4"></video>
92
+
93
+
94
+ ### Generate Video 2: Kim Hyesung Sea
95
+ Run the following command to generate a video of Prince Kim Hyesung:
96
+ ```bash
97
+ python hv_generate_video.py \
98
+ --fp8 \
99
+ --video_size 544 960 \
100
+ --video_length 60 \
101
+ --infer_steps 30 \
102
+ --prompt "surrealist painting of Kim Hyesung. underwater glow, deep sea. a peaceful zen garden with koi pond in the background." \
103
+ --save_path . \
104
+ --output_type both \
105
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
106
+ --attn_mode sdpa \
107
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
108
+ --vae_chunk_size 32 \
109
+ --vae_spatial_tile_sample_min_size 128 \
110
+ --text_encoder1 ckpts/text_encoder \
111
+ --text_encoder2 ckpts/text_encoder_2 \
112
+ --seed 1234 \
113
+ --lora_multiplier 1.0 \
114
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
115
+ ```
116
+
117
+
118
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-111149_1234.mp4"></video>
119
+
120
+ ### Generate Video 3: Kim Hyesung Class
121
+ Run the following command to generate a video of Prince Kim Hyesung:
122
+ ```bash
123
+ python hv_generate_video.py \
124
+ --fp8 \
125
+ --video_size 544 960 \
126
+ --video_length 60 \
127
+ --infer_steps 30 \
128
+ --prompt "Kim Hyesung, a young person with straight, dark hair, wearing a white school uniform. They are seated in a classroom with other students, all dressed in white uniforms. The background includes a wooden door and blurred figures of other students, suggesting a school setting. The lighting is soft, and the image has a slightly grainy texture, adding to the realistic and candid feel." \
129
+ --save_path . \
130
+ --output_type both \
131
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
132
+ --attn_mode sdpa \
133
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
134
+ --vae_chunk_size 32 \
135
+ --vae_spatial_tile_sample_min_size 128 \
136
+ --text_encoder1 ckpts/text_encoder \
137
+ --text_encoder2 ckpts/text_encoder_2 \
138
+ --seed 1234 \
139
+ --lora_multiplier 1.0 \
140
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
141
+ ```
142
+
143
+
144
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-114357_1234.mp4"></video>
145
+
146
+
147
+ ---
148
+
149
+ ## Notes
150
+ - Ensure you have sufficient GPU resources for video generation.
151
+ - Adjust the `--video_size`, `--video_length`, and `--infer_steps` parameters as needed for different output qualities and lengths.
152
+ - The `--prompt` parameter can be modified to generate videos with different scenes or actions.
153
+
154
+ ---
20250204-104122_1234.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:122cf6f9fd13478d006c95fdfa6caafb7a6f138b2b09a814f471cb7e52044224
3
+ size 1109401
20250204-111149_1234.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc68ab4b39c8ec500fbe1aab1da3c4853967c1da29f2e771fef0d6d6287efd44
3
+ size 1148331
20250204-114357_1234.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dde5e550ad8af64ca53b2fd1bb488b5f9344f0ad01cd0bb7c7da8459bc144a8d
3
+ size 1058260
README.md ADDED
@@ -0,0 +1,154 @@
1
+ # Prince Star (Kim Hyesung) HunyuanVideo LoRA
2
+
3
+ This repository contains the necessary setup and scripts to generate videos using the HunyuanVideo model with a LoRA (Low-Rank Adaptation) fine-tuned for Kim Hyesung. Below are the instructions to install dependencies, download models, and run the demo.
4
+
5
+ ---
6
+
7
+ ## Installation
8
+
9
+ ### Step 1: Install System Dependencies
10
+ Run the following command to install required system packages:
11
+ ```bash
12
+ sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
13
+ ```
14
+
15
+ ### Step 2: Clone the Repository
16
+ Clone the repository and navigate to the project directory:
17
+ ```bash
18
+ git clone https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora
19
+ cd Prince_Star_HunyuanVideo_lora
20
+ ```
21
+
22
+ ### Step 3: Install Python Dependencies
23
+ Install the required Python packages:
24
+ ```bash
25
+ conda create -n py310 python=3.10
26
+ conda activate py310
27
+ pip install ipykernel
28
+ python -m ipykernel install --user --name py310 --display-name "py310"
29
+
30
+ pip install -r requirements.txt
31
+ pip install ascii-magic matplotlib tensorboard huggingface_hub
32
+ pip install moviepy==1.0.3
33
+ pip install sageattention==1.0.6
34
+
35
+ pip install torch==2.5.0 torchvision
36
+ ```
37
+
38
+ ---
39
+
40
+ ## Download Models
41
+
42
+ ### Step 1: Download HunyuanVideo Model
43
+ Download the HunyuanVideo model and place it in the `ckpts` directory:
44
+ ```bash
45
+ huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
46
+ ```
47
+
48
+ ### Step 2: Download LLaVA Model
49
+ Download the LLaVA model and preprocess it:
50
+ ```bash
51
+ cd ckpts
52
+ huggingface-cli download xtuner/llava-llama-3-8b-v1_1-transformers --local-dir ./llava-llama-3-8b-v1_1-transformers
53
+ wget https://raw.githubusercontent.com/Tencent/HunyuanVideo/refs/heads/main/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py
54
+ python preprocess_text_encoder_tokenizer_utils.py --input_dir llava-llama-3-8b-v1_1-transformers --output_dir text_encoder
55
+ ```
56
+
57
+ ### Step 3: Download CLIP Model
58
+ Download the CLIP model for the text encoder:
59
+ ```bash
60
+ huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./text_encoder_2
61
+ ```
62
+
63
+ ---
64
+
65
+ ## Demo
66
+
67
+ ### Generate Video 1: Kim Hyesung Sun
68
+ Run the following command to generate a video of Prince Kim Hyesung:
69
+ ```bash
70
+ python hv_generate_video.py \
71
+ --fp8 \
72
+ --video_size 544 960 \
73
+ --video_length 60 \
74
+ --infer_steps 30 \
75
+ --prompt "fantastic artwork of Kim Hyesung. warm sunset in a rural village. the interior of a futuristic spaceship in the background." \
76
+ --save_path . \
77
+ --output_type both \
78
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
79
+ --attn_mode sdpa \
80
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
81
+ --vae_chunk_size 32 \
82
+ --vae_spatial_tile_sample_min_size 128 \
83
+ --text_encoder1 ckpts/text_encoder \
84
+ --text_encoder2 ckpts/text_encoder_2 \
85
+ --seed 1234 \
86
+ --lora_multiplier 1.0 \
87
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
88
+ ```
89
+
90
+
91
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-104122_1234.mp4"></video>
92
+
93
+
94
+ ### Generate Video 2: Kim Hyesung Sea
95
+ Run the following command to generate a video of Prince Kim Hyesung:
96
+ ```bash
97
+ python hv_generate_video.py \
98
+ --fp8 \
99
+ --video_size 544 960 \
100
+ --video_length 60 \
101
+ --infer_steps 30 \
102
+ --prompt "surrealist painting of Kim Hyesung. underwater glow, deep sea. a peaceful zen garden with koi pond in the background." \
103
+ --save_path . \
104
+ --output_type both \
105
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
106
+ --attn_mode sdpa \
107
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
108
+ --vae_chunk_size 32 \
109
+ --vae_spatial_tile_sample_min_size 128 \
110
+ --text_encoder1 ckpts/text_encoder \
111
+ --text_encoder2 ckpts/text_encoder_2 \
112
+ --seed 1234 \
113
+ --lora_multiplier 1.0 \
114
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
115
+ ```
116
+
117
+
118
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-111149_1234.mp4"></video>
119
+
120
+ ### Generate Video 3: Kim Hyesung Class
121
+ Run the following command to generate a video of Prince Kim Hyesung:
122
+ ```bash
123
+ python hv_generate_video.py \
124
+ --fp8 \
125
+ --video_size 544 960 \
126
+ --video_length 60 \
127
+ --infer_steps 30 \
128
+ --prompt "Kim Hyesung, a young person with straight, dark hair, wearing a white school uniform. They are seated in a classroom with other students, all dressed in white uniforms. The background includes a wooden door and blurred figures of other students, suggesting a school setting. The lighting is soft, and the image has a slightly grainy texture, adding to the realistic and candid feel." \
129
+ --save_path . \
130
+ --output_type both \
131
+ --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
132
+ --attn_mode sdpa \
133
+ --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
134
+ --vae_chunk_size 32 \
135
+ --vae_spatial_tile_sample_min_size 128 \
136
+ --text_encoder1 ckpts/text_encoder \
137
+ --text_encoder2 ckpts/text_encoder_2 \
138
+ --seed 1234 \
139
+ --lora_multiplier 1.0 \
140
+ --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
141
+ ```
142
+
143
+
144
+ <video controls autoplay src="https://huggingface.co/svjack/Prince_Star_HunyuanVideo_lora/resolve/main/20250204-114357_1234.mp4"></video>
145
+
146
+
147
+ ---
148
+
149
+ ## Notes
150
+ - Ensure you have sufficient GPU resources for video generation.
151
+ - Adjust the `--video_size`, `--video_length`, and `--infer_steps` parameters as needed for different output qualities and lengths.
152
+ - The `--prompt` parameter can be modified to generate videos with different scenes or actions.
153
+
154
+ ---
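As a rough example of the adjustments described in the Notes above, the same demo command can be rerun with a shorter clip and fewer inference steps for quicker previews; the values below are illustrative rather than tuned settings:

```bash
# Illustrative quick-preview run: only --video_length and --infer_steps differ from the demos above.
python hv_generate_video.py \
    --fp8 \
    --video_size 544 960 \
    --video_length 30 \
    --infer_steps 20 \
    --prompt "fantastic artwork of Kim Hyesung. warm sunset in a rural village. the interior of a futuristic spaceship in the background." \
    --save_path . \
    --output_type both \
    --dit ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
    --attn_mode sdpa \
    --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
    --vae_chunk_size 32 \
    --vae_spatial_tile_sample_min_size 128 \
    --text_encoder1 ckpts/text_encoder \
    --text_encoder2 ckpts/text_encoder_2 \
    --seed 1234 \
    --lora_multiplier 1.0 \
    --lora_weight Star_im_lora_dir/Star_single_im_lora-000040.safetensors
```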
Star_im_lora_dir/Star_single_im_lora-000030.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a24b6c96e93260558ca1c77eab6ecee4f350b641624f96b957336fc7d25fb3d
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000031.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34224c0a987eb6919f244d4092dad34865b2c98e8a40eea2db4764ef580b72f6
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000032.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee0603ba5fb45a9e0b2ba726bd3e868c681d72e83ac509d764ba99ebd635db2d
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000033.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3da0c7ff2cbf98a27818e079b94ae21390068489da23b29fba6a045f2e55a1c6
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000034.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66a46a68a5385bbe018d3b3fe307f21e544e5b9546a4057ca591c6df8314e25e
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000035.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f23bdd6d9438cd0260df5ad17244a98e0e615fa507ebf092ee799443150b69f9
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000036.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfd3c266c8de8ba2886b73439e78b3288ac0a8b18bf52d07417c30797ca151cd
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000037.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc71502683b12f2532b71c9099733eb31ec2bb2eef6bf185a6b5c032df76dcd6
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000038.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d565b636caffdd453b99ceec541e67dff1fe803814dcb77a76ef84c11f5fac42
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000039.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44d9dfb1e994b153be2d406eb93aa4578719f67e2b730d80a6764ed21e72bff1
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000040.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:582f0f6882a73d7b4673e223c2276f76ac129b5685ec4517c4d7dbe948b622b4
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000041.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e43c3b84a50c1c3deef78150aa29d4f0203d377d331ccaf0d15e5ae097688bab
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000042.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b2be76884c7d20c0943838406ce030f04197f21baa4e622ab97400a4d8e37fe
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000043.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4feaee78e3b65dec3465b7131c1a2adeef5b0ff8090199dbe4fb5459a74016a3
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000044.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53a2171f59b70af7820d2cb032b35f78cb1f93193497a2b6a1c03853a85103ed
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000045.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a4ec1d41c6938541f05fb65b5c614afea0613157bc084fb4ccd6ac0a96975d2
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000046.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67c4bb9694a00c0bbaae59e9e615aa6fd6dc556a153b90c1766ec87efed0cd23
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000047.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee33e2660d6fc184b9e94daf46140ff4a686fb23dcfc150e98b18c04f51641e
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000048.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82755bc9dddeed5932690c6949e61ab26f5edf2ef358cc1caad6335d1b49bcdd
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora-000049.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99def0d3d2a584e6c382a6cf7fe737716b96dfebbd2f82e722ef7cc2ad751b57
3
+ size 322557552
Star_im_lora_dir/Star_single_im_lora.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec8e65f91b49ae338ec7b8d14d0273bac63437027534cf5ad2ebbee0ac1f6b02
3
+ size 322557552
cache_latents.py ADDED
@@ -0,0 +1,245 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ from PIL import Image
12
+
13
+ import logging
14
+
15
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache
16
+ from hunyuan_model.vae import load_vae
17
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
18
+ from utils.model_utils import str_to_dtype
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
+ def show_image(image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray]) -> int:
25
+ import cv2
26
+
27
+ imgs = (
28
+ [image]
29
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
30
+ else [image[0], image[-1]]
31
+ )
32
+ if len(imgs) > 1:
33
+ print(f"Number of images: {len(image)}")
34
+ for i, img in enumerate(imgs):
35
+ if len(imgs) > 1:
36
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
37
+ else:
38
+ print(f"Image: {img.shape}")
39
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
40
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
41
+ cv2.imshow("image", cv2_img)
42
+ k = cv2.waitKey(0)
43
+ cv2.destroyAllWindows()
44
+ if k == ord("q") or k == ord("d"):
45
+ return k
46
+ return k
47
+
48
+
49
+ def show_console(
50
+ image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray],
51
+ width: int,
52
+ back: str,
53
+ interactive: bool = False,
54
+ ) -> int:
55
+ from ascii_magic import from_pillow_image, Back
56
+
57
+ back = None
58
+ if back is not None:
59
+ back = getattr(Back, back.upper())
60
+
61
+ k = None
62
+ imgs = (
63
+ [image]
64
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
65
+ else [image[0], image[-1]]
66
+ )
67
+ if len(imgs) > 1:
68
+ print(f"Number of images: {len(image)}")
69
+ for i, img in enumerate(imgs):
70
+ if len(imgs) > 1:
71
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
72
+ else:
73
+ print(f"Image: {img.shape}")
74
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
75
+ ascii_img = from_pillow_image(pil_img)
76
+ ascii_img.to_terminal(columns=width, back=back)
77
+
78
+ if interactive:
79
+ k = input("Press q to quit, d to next dataset, other key to next: ")
80
+ if k == "q" or k == "d":
81
+ return ord(k)
82
+
83
+ if not interactive:
84
+ return ord(" ")
85
+ return ord(k) if k else ord(" ")
86
+
87
+
88
+ def show_datasets(
89
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
90
+ ):
91
+ print(f"d: next dataset, q: quit")
92
+
93
+ num_workers = max(1, os.cpu_count() - 1)
94
+ for i, dataset in enumerate(datasets):
95
+ print(f"Dataset [{i}]")
96
+ batch_index = 0
97
+ num_images_to_show = console_num_images
98
+ k = None
99
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
100
+ print(f"bucket resolution: {key}, count: {len(batch)}")
101
+ for j, item_info in enumerate(batch):
102
+ item_info: ItemInfo
103
+ print(f"{batch_index}-{j}: {item_info}")
104
+ if debug_mode == "image":
105
+ k = show_image(item_info.content)
106
+ elif debug_mode == "console":
107
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
108
+ if num_images_to_show is not None:
109
+ num_images_to_show -= 1
110
+ if num_images_to_show == 0:
111
+ k = ord("d") # next dataset
112
+
113
+ if k == ord("q"):
114
+ return
115
+ elif k == ord("d"):
116
+ break
117
+ if k == ord("d"):
118
+ break
119
+ batch_index += 1
120
+
121
+
122
+ def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
123
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
124
+ if len(contents.shape) == 4:
125
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
126
+
127
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
128
+ contents = contents.to(vae.device, dtype=vae.dtype)
129
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
130
+
131
+ # print(f"encode batch: {contents.shape}")
132
+ with torch.no_grad():
133
+ latent = vae.encode(contents).latent_dist.sample()
134
+ latent = latent * vae.config.scaling_factor
135
+
136
+ # # debug: decode and save
137
+ # with torch.no_grad():
138
+ # latent_to_decode = latent / vae.config.scaling_factor
139
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
140
+ # images = (images / 2 + 0.5).clamp(0, 1)
141
+ # images = images.cpu().float().numpy()
142
+ # images = (images * 255).astype(np.uint8)
143
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
144
+ # for b in range(images.shape[0]):
145
+ # for f in range(images.shape[1]):
146
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
147
+ # img = Image.fromarray(images[b, f])
148
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
149
+
150
+ for item, l in zip(batch, latent):
151
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
152
+ save_latent_cache(item, l)
153
+
154
+
155
+ def main(args):
156
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
157
+ device = torch.device(device)
158
+
159
+ # Load dataset config
160
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
161
+ logger.info(f"Load dataset config from {args.dataset_config}")
162
+ user_config = config_utils.load_user_config(args.dataset_config)
163
+ blueprint = blueprint_generator.generate(user_config, args)
164
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
165
+
166
+ datasets = train_dataset_group.datasets
167
+
168
+ if args.debug_mode is not None:
169
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
170
+ return
171
+
172
+ assert args.vae is not None, "vae checkpoint is required"
173
+
174
+ # Load VAE model: HunyuanVideo VAE model is float16
175
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
176
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
177
+ vae.eval()
178
+ print(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
179
+
180
+ if args.vae_chunk_size is not None:
181
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
182
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
183
+ if args.vae_spatial_tile_sample_min_size is not None:
184
+ vae.enable_spatial_tiling(True)
185
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
186
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
187
+ elif args.vae_tiling:
188
+ vae.enable_spatial_tiling(True)
189
+
190
+ # Encode images
191
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
192
+ for i, dataset in enumerate(datasets):
193
+ print(f"Encoding dataset [{i}]")
194
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
195
+ if args.skip_existing:
196
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
197
+ if len(filtered_batch) == 0:
198
+ continue
199
+ batch = filtered_batch
200
+
201
+ bs = args.batch_size if args.batch_size is not None else len(batch)
202
+ for i in range(0, len(batch), bs):
203
+ encode_and_save_batch(vae, batch[i : i + bs])
204
+
205
+
206
+ def setup_parser():
207
+ parser = argparse.ArgumentParser()
208
+
209
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
210
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
211
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
212
+ parser.add_argument(
213
+ "--vae_tiling",
214
+ action="store_true",
215
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
216
+ )
217
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
218
+ parser.add_argument(
219
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
220
+ )
221
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
222
+ parser.add_argument(
223
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
224
+ )
225
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
226
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
227
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
228
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
229
+ parser.add_argument(
230
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
231
+ )
232
+ parser.add_argument(
233
+ "--console_num_images",
234
+ type=int,
235
+ default=None,
236
+ help="debug mode: not interactive, number of images to show for each dataset",
237
+ )
238
+ return parser
239
+
240
+
241
+ if __name__ == "__main__":
242
+ parser = setup_parser()
243
+
244
+ args = parser.parse_args()
245
+ main(args)
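A minimal invocation sketch for the latent-caching script above, assuming a hypothetical dataset TOML and the VAE path from the README; only flags defined in `setup_parser()` are used:

```bash
# Hypothetical dataset config path; VAE options mirror the README demo commands.
python cache_latents.py \
    --dataset_config path/to/dataset.toml \
    --vae ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
    --vae_chunk_size 32 \
    --vae_spatial_tile_sample_min_size 128 \
    --skip_existing
```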
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
26
+ data_type = "video" # video only, image is not supported
27
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
28
+
29
+ with torch.no_grad():
30
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
31
+
32
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
33
+
34
+
35
+ def encode_and_save_batch(
36
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
37
+ ):
38
+ prompts = [item.caption for item in batch]
39
+ # print(prompts)
40
+
41
+ # encode prompt
42
+ if accelerator is not None:
43
+ with accelerator.autocast():
44
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
45
+ else:
46
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
47
+
48
+ # # convert to fp16 if needed
49
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
50
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
51
+
52
+ # save prompt cache
53
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
54
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
+ def main(args):
58
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
59
+ device = torch.device(device)
60
+
61
+ # Load dataset config
62
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
63
+ logger.info(f"Load dataset config from {args.dataset_config}")
64
+ user_config = config_utils.load_user_config(args.dataset_config)
65
+ blueprint = blueprint_generator.generate(user_config, args)
66
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
67
+
68
+ datasets = train_dataset_group.datasets
69
+
70
+ # define accelerator for fp8 inference
71
+ accelerator = None
72
+ if args.fp8_llm:
73
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
74
+
75
+ # define encode function
76
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
77
+
78
+ def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):
79
+ for i, dataset in enumerate(datasets):
80
+ print(f"Encoding dataset [{i}]")
81
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
82
+ if args.skip_existing:
83
+ filtered_batch = [item for item in batch if not os.path.exists(item.text_encoder_output_cache_path)]
84
+ if len(filtered_batch) == 0:
85
+ continue
86
+ batch = filtered_batch
87
+
88
+ bs = args.batch_size if args.batch_size is not None else len(batch)
89
+ for i in range(0, len(batch), bs):
90
+ encode_and_save_batch(text_encoder, batch[i : i + bs], is_llm, accelerator)
91
+
92
+ # Load Text Encoder 1
93
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
94
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
95
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
96
+ text_encoder_1.to(device=device)
97
+
98
+ # Encode with Text Encoder 1
99
+ logger.info("Encoding with Text Encoder 1")
100
+ encode_for_text_encoder(text_encoder_1, is_llm=True)
101
+ del text_encoder_1
102
+
103
+ # Load Text Encoder 2
104
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
105
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
106
+ text_encoder_2.to(device=device)
107
+
108
+ # Encode with Text Encoder 2
109
+ logger.info("Encoding with Text Encoder 2")
110
+ encode_for_text_encoder(text_encoder_2, is_llm=False)
111
+ del text_encoder_2
112
+
113
+
114
+ def setup_parser():
115
+ parser = argparse.ArgumentParser()
116
+
117
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
118
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
119
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
120
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
121
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
122
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
123
+ parser.add_argument(
124
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
125
+ )
126
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
127
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
128
+ return parser
129
+
130
+
131
+ if __name__ == "__main__":
132
+ parser = setup_parser()
133
+
134
+ args = parser.parse_args()
135
+ main(args)
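A minimal invocation sketch for the text-encoder caching script above, assuming a hypothetical dataset TOML and the text encoder directories prepared in the README's download steps:

```bash
# Hypothetical dataset config path; encoder directories follow the README layout.
python cache_text_encoder_outputs.py \
    --dataset_config path/to/dataset.toml \
    --text_encoder1 ckpts/text_encoder \
    --text_encoder2 ckpts/text_encoder_2 \
    --batch_size 16 \
    --skip_existing
```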
convert_lora.py ADDED
@@ -0,0 +1,129 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ def convert_from_diffusers(prefix, weights_sd):
16
+ # convert from diffusers(?) to default LoRA
17
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
18
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
19
+ # note: Diffusers has no alpha, so alpha is set to rank
20
+ new_weights_sd = {}
21
+ lora_dims = {}
22
+ for key, weight in weights_sd.items():
23
+ diffusers_prefix, key_body = key.split(".", 1)
24
+ if diffusers_prefix != "diffusion_model":
25
+ logger.warning(f"unexpected key: {key} in diffusers format")
26
+ continue
27
+
28
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
29
+ new_weights_sd[new_key] = weight
30
+
31
+ lora_name = new_key.split(".")[0] # before first dot
32
+ if lora_name not in lora_dims and "lora_down" in new_key:
33
+ lora_dims[lora_name] = weight.shape[0]
34
+
35
+ # add alpha with rank
36
+ for lora_name, dim in lora_dims.items():
37
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
38
+
39
+ return new_weights_sd
40
+
41
+
42
+ def convert_to_diffusers(prefix, weights_sd):
43
+ # convert from default LoRA to diffusers
44
+
45
+ # get alphas
46
+ lora_alphas = {}
47
+ for key, weight in weights_sd.items():
48
+ if key.startswith(prefix):
49
+ lora_name = key.split(".", 1)[0] # before first dot
50
+ if lora_name not in lora_alphas and "alpha" in key:
51
+ lora_alphas[lora_name] = weight
52
+
53
+ new_weights_sd = {}
54
+ for key, weight in weights_sd.items():
55
+ if key.startswith(prefix):
56
+ if "alpha" in key:
57
+ continue
58
+
59
+ lora_name = key.split(".", 1)[0] # before first dot
60
+
61
+ # HunyuanVideo lora name to module name: ugly but works
62
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
63
+ module_name = module_name.replace("_", ".") # replace "_" with "."
64
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
65
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
66
+ module_name = module_name.replace("img.", "img_") # fix img
67
+ module_name = module_name.replace("txt.", "txt_") # fix txt
68
+ module_name = module_name.replace("attn.", "attn_") # fix attn
69
+
70
+ diffusers_prefix = "diffusion_model"
71
+ if "lora_down" in key:
72
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
73
+ dim = weight.shape[0]
74
+ elif "lora_up" in key:
75
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
76
+ dim = weight.shape[1]
77
+ else:
78
+ logger.warning(f"unexpected key: {key} in default LoRA format")
79
+ continue
80
+
81
+ # scale weight by alpha
82
+ if lora_name in lora_alphas:
83
+ # we scale both down and up, so scale is sqrt
84
+ scale = lora_alphas[lora_name] / dim
85
+ scale = scale.sqrt()
86
+ weight = weight * scale
87
+ else:
88
+ logger.warning(f"missing alpha for {lora_name}")
89
+
90
+ new_weights_sd[new_key] = weight
91
+
92
+ return new_weights_sd
93
+
94
+
95
+ def convert(input_file, output_file, target_format):
96
+ logger.info(f"loading {input_file}")
97
+ weights_sd = load_file(input_file)
98
+ with safe_open(input_file, framework="pt") as f:
99
+ metadata = f.metadata()
100
+
101
+ logger.info(f"converting to {target_format}")
102
+ prefix = "lora_unet_"
103
+ if target_format == "default":
104
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
105
+ metadata = metadata or {}
106
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
107
+ elif target_format == "other":
108
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
109
+ else:
110
+ raise ValueError(f"unknown target format: {target_format}")
111
+
112
+ logger.info(f"saving to {output_file}")
113
+ save_file(new_weights_sd, output_file, metadata=metadata)
114
+
115
+ logger.info("done")
116
+
117
+
118
+ def parse_args():
119
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
120
+ parser.add_argument("--input", type=str, required=True, help="input model file")
121
+ parser.add_argument("--output", type=str, required=True, help="output model file")
122
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ if __name__ == "__main__":
128
+ args = parse_args()
129
+ convert(args.input, args.output, args.target)
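A usage sketch for the converter above, assuming you want the diffusers-style (`other`) layout; the output filename is illustrative:

```bash
# Convert the repo's LoRA from the default (lora_unet_*) format to the diffusers-style format.
python convert_lora.py \
    --input Star_im_lora_dir/Star_single_im_lora.safetensors \
    --output Star_single_im_lora_diffusers.safetensors \
    --target other
```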
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ @dataclass
28
+ class BaseDatasetParams:
29
+ resolution: Tuple[int, int] = (960, 544)
30
+ enable_bucket: bool = False
31
+ bucket_no_upscale: bool = False
32
+ caption_extension: Optional[str] = None
33
+ batch_size: int = 1
34
+ cache_directory: Optional[str] = None
35
+ debug_dataset: bool = False
36
+
37
+
38
+ @dataclass
39
+ class ImageDatasetParams(BaseDatasetParams):
40
+ image_directory: Optional[str] = None
41
+ image_jsonl_file: Optional[str] = None
42
+
43
+
44
+ @dataclass
45
+ class VideoDatasetParams(BaseDatasetParams):
46
+ video_directory: Optional[str] = None
47
+ video_jsonl_file: Optional[str] = None
48
+ target_frames: Sequence[int] = (1,)
49
+ frame_extraction: Optional[str] = "head"
50
+ frame_stride: Optional[int] = 1
51
+ frame_sample: Optional[int] = 1
52
+
53
+
54
+ @dataclass
55
+ class DatasetBlueprint:
56
+ is_image_dataset: bool
57
+ params: Union[ImageDatasetParams, VideoDatasetParams]
58
+
59
+
60
+ @dataclass
61
+ class DatasetGroupBlueprint:
62
+ datasets: Sequence[DatasetBlueprint]
63
+
64
+
65
+ @dataclass
66
+ class Blueprint:
67
+ dataset_group: DatasetGroupBlueprint
68
+
69
+
70
+ class ConfigSanitizer:
71
+ # @curry
72
+ @staticmethod
73
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
74
+ Schema(ExactSequence([klass, klass]))(value)
75
+ return tuple(value)
76
+
77
+ # @curry
78
+ @staticmethod
79
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
80
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
81
+ try:
82
+ Schema(klass)(value)
83
+ return (value, value)
84
+ except:
85
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
86
+
87
+ # datasets schema
88
+ DATASET_ASCENDABLE_SCHEMA = {
89
+ "caption_extension": str,
90
+ "batch_size": int,
91
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
92
+ "enable_bucket": bool,
93
+ "bucket_no_upscale": bool,
94
+ }
95
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
96
+ "image_directory": str,
97
+ "image_jsonl_file": str,
98
+ "cache_directory": str,
99
+ }
100
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
101
+ "video_directory": str,
102
+ "video_jsonl_file": str,
103
+ "target_frames": [int],
104
+ "frame_extraction": str,
105
+ "frame_stride": int,
106
+ "frame_sample": int,
107
+ "cache_directory": str,
108
+ }
109
+
110
+ # options handled by argparse but not handled by user config
111
+ ARGPARSE_SPECIFIC_SCHEMA = {
112
+ "debug_dataset": bool,
113
+ }
114
+
115
+ def __init__(self) -> None:
116
+ self.image_dataset_schema = self.__merge_dict(
117
+ self.DATASET_ASCENDABLE_SCHEMA,
118
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
119
+ )
120
+ self.video_dataset_schema = self.__merge_dict(
121
+ self.DATASET_ASCENDABLE_SCHEMA,
122
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
123
+ )
124
+
125
+ def validate_flex_dataset(dataset_config: dict):
126
+ if "target_frames" in dataset_config:
127
+ return Schema(self.video_dataset_schema)(dataset_config)
128
+ else:
129
+ return Schema(self.image_dataset_schema)(dataset_config)
130
+
131
+ self.dataset_schema = validate_flex_dataset
132
+
133
+ self.general_schema = self.__merge_dict(
134
+ self.DATASET_ASCENDABLE_SCHEMA,
135
+ )
136
+ self.user_config_validator = Schema(
137
+ {
138
+ "general": self.general_schema,
139
+ "datasets": [self.dataset_schema],
140
+ }
141
+ )
142
+ self.argparse_schema = self.__merge_dict(
143
+ self.ARGPARSE_SPECIFIC_SCHEMA,
144
+ )
145
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
146
+
147
+ def sanitize_user_config(self, user_config: dict) -> dict:
148
+ try:
149
+ return self.user_config_validator(user_config)
150
+ except MultipleInvalid:
151
+ # TODO: clarify the error message
152
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
153
+ raise
154
+
155
+ # NOTE: In nature, argument parser result is not needed to be sanitize
156
+ # However this will help us to detect program bug
157
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
158
+ try:
159
+ return self.argparse_config_validator(argparse_namespace)
160
+ except MultipleInvalid:
161
+ # XXX: this should be a bug
162
+ logger.error(
163
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
164
+ )
165
+ raise
166
+
167
+ # NOTE: value would be overwritten by latter dict if there is already the same key
168
+ @staticmethod
169
+ def __merge_dict(*dict_list: dict) -> dict:
170
+ merged = {}
171
+ for schema in dict_list:
172
+ # merged |= schema
173
+ for k, v in schema.items():
174
+ merged[k] = v
175
+ return merged
176
+
177
+
178
+ class BlueprintGenerator:
179
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
180
+
181
+ def __init__(self, sanitizer: ConfigSanitizer):
182
+ self.sanitizer = sanitizer
183
+
184
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
185
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
186
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
187
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
188
+
189
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
190
+ general_config = sanitized_user_config.get("general", {})
191
+
192
+ dataset_blueprints = []
193
+ for dataset_config in sanitized_user_config.get("datasets", []):
194
+ is_image_dataset = "target_frames" not in dataset_config
195
+ if is_image_dataset:
196
+ dataset_params_klass = ImageDatasetParams
197
+ else:
198
+ dataset_params_klass = VideoDatasetParams
199
+
200
+ params = self.generate_params_by_fallbacks(
201
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
202
+ )
203
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
204
+
205
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
206
+
207
+ return Blueprint(dataset_group_blueprint)
208
+
209
+ @staticmethod
210
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
211
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
212
+ search_value = BlueprintGenerator.search_value
213
+ default_params = asdict(param_klass())
214
+ param_names = default_params.keys()
215
+
216
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
217
+
218
+ return param_klass(**params)
219
+
220
+ @staticmethod
221
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
222
+ for cand in fallbacks:
223
+ value = cand.get(key)
224
+ if value is not None:
225
+ return value
226
+
227
+ return default_value
228
+
229
+
230
+ # if training is True, it will return a dataset group for training, otherwise for caching
231
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
232
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
233
+
234
+ for dataset_blueprint in dataset_group_blueprint.datasets:
235
+ if dataset_blueprint.is_image_dataset:
236
+ dataset_klass = ImageDataset
237
+ else:
238
+ dataset_klass = VideoDataset
239
+
240
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
241
+ datasets.append(dataset)
242
+
243
+ # print info
244
+ info = ""
245
+ for i, dataset in enumerate(datasets):
246
+ is_image_dataset = isinstance(dataset, ImageDataset)
247
+ info += dedent(
248
+ f"""\
249
+ [Dataset {i}]
250
+ is_image_dataset: {is_image_dataset}
251
+ resolution: {dataset.resolution}
252
+ batch_size: {dataset.batch_size}
253
+ caption_extension: "{dataset.caption_extension}"
254
+ enable_bucket: {dataset.enable_bucket}
255
+ bucket_no_upscale: {dataset.bucket_no_upscale}
256
+ cache_directory: "{dataset.cache_directory}"
257
+ debug_dataset: {dataset.debug_dataset}
258
+ """
259
+ )
260
+
261
+ if is_image_dataset:
262
+ info += indent(
263
+ dedent(
264
+ f"""\
265
+ image_directory: "{dataset.image_directory}"
266
+ image_jsonl_file: "{dataset.image_jsonl_file}"
267
+ \n"""
268
+ ),
269
+ " ",
270
+ )
271
+ else:
272
+ info += indent(
273
+ dedent(
274
+ f"""\
275
+ video_directory: "{dataset.video_directory}"
276
+ video_jsonl_file: "{dataset.video_jsonl_file}"
277
+ target_frames: {dataset.target_frames}
278
+ frame_extraction: {dataset.frame_extraction}
279
+ frame_stride: {dataset.frame_stride}
280
+ frame_sample: {dataset.frame_sample}
281
+ \n"""
282
+ ),
283
+ " ",
284
+ )
285
+ logger.info(f"{info}")
286
+
287
+ # make buckets first because it determines the length of dataset
288
+ # and set the same seed for all datasets
289
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
290
+ for i, dataset in enumerate(datasets):
291
+ # logger.info(f"[Dataset {i}]")
292
+ dataset.set_seed(seed)
293
+ if training:
294
+ dataset.prepare_for_training()
295
+
296
+ return DatasetGroup(datasets)
297
+
298
+
299
+ def load_user_config(file: str) -> dict:
300
+ file: Path = Path(file)
301
+ if not file.is_file():
302
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
303
+
304
+ if file.name.lower().endswith(".json"):
305
+ try:
306
+ with open(file, "r") as f:
307
+ config = json.load(f)
308
+ except Exception:
309
+ logger.error(
310
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
311
+ )
312
+ raise
313
+ elif file.name.lower().endswith(".toml"):
314
+ try:
315
+ config = toml.load(file)
316
+ except Exception:
317
+ logger.error(
318
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
319
+ )
320
+ raise
321
+ else:
322
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
323
+
324
+ return config
325
+
326
+
327
+ # for config test
328
+ if __name__ == "__main__":
329
+ parser = argparse.ArgumentParser()
330
+ parser.add_argument("dataset_config")
331
+ config_args, remain = parser.parse_known_args()
332
+
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--debug_dataset", action="store_true")
335
+ argparse_namespace = parser.parse_args(remain)
336
+
337
+ logger.info("[argparse_namespace]")
338
+ logger.info(f"{vars(argparse_namespace)}")
339
+
340
+ user_config = load_user_config(config_args.dataset_config)
341
+
342
+ logger.info("")
343
+ logger.info("[user_config]")
344
+ logger.info(f"{user_config}")
345
+
346
+ sanitizer = ConfigSanitizer()
347
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
348
+
349
+ logger.info("")
350
+ logger.info("[sanitized_user_config]")
351
+ logger.info(f"{sanitized_user_config}")
352
+
353
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
354
+
355
+ logger.info("")
356
+ logger.info("[blueprint]")
357
+ logger.info(f"{blueprint}")
358
+
359
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
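The `__main__` block above is a small config checker; a sketch of invoking it with a hypothetical TOML path, run as a module so the relative import of `image_video_dataset` resolves:

```bash
# Validate and print a dataset config without training; the path is a placeholder.
python -m dataset.config_utils path/to/dataset.toml --debug_dataset
```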
dataset/dataset_config.md ADDED
@@ -0,0 +1,293 @@
1
+ ## Dataset Configuration
2
+
3
+ Please create a TOML file for dataset configuration.
4
+
5
+ Image and video datasets are supported. A configuration file can include multiple datasets, image or video, each described either by caption text files or by a metadata JSONL file.
6
+
7
+ ### Sample for Image Dataset with Caption Text Files
8
+
9
+ ```toml
10
+ # resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
11
+
12
+ # general configurations
13
+ [general]
14
+ resolution = [960, 544]
15
+ caption_extension = ".txt"
16
+ batch_size = 1
17
+ enable_bucket = true
18
+ bucket_no_upscale = false
19
+
20
+ [[datasets]]
21
+ image_directory = "/path/to/image_dir"
22
+
23
+ # other datasets can be added here. each dataset can have different configurations
24
+ ```
25
+
26
+ ### Sample for Image Dataset with Metadata JSONL File
27
+
28
+ ```toml
29
+ # resolution, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
30
+ # caption_extension is not required for metadata jsonl file
31
+ # cache_directory is required for each dataset with metadata jsonl file
32
+
33
+ # general configurations
34
+ [general]
35
+ resolution = [960, 544]
36
+ batch_size = 1
37
+ enable_bucket = true
38
+ bucket_no_upscale = false
39
+
40
+ [[datasets]]
41
+ image_jsonl_file = "/path/to/metadata.jsonl"
42
+ cache_directory = "/path/to/cache_directory"
43
+
44
+ # other datasets can be added here. each dataset can have different configurations
45
+ ```
46
+
47
+ JSONL file format for metadata:
48
+
49
+ ```json
50
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
51
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
52
+ ```
53
+
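+ If it helps, a metadata JSONL like the one above can be assembled from an image folder with matching caption `.txt` files using a short script along these lines (a minimal sketch; the paths and glob pattern are placeholders, not part of this repository):
+
+ ```python
+ import glob
+ import json
+ import os
+
+ image_dir = "/path/to/image_dir"          # placeholder
+ output_jsonl = "/path/to/metadata.jsonl"  # placeholder
+
+ with open(output_jsonl, "w", encoding="utf-8") as out:
+     # adjust the pattern if your images use other extensions
+     for image_path in sorted(glob.glob(os.path.join(image_dir, "*.jpg"))):
+         caption_path = os.path.splitext(image_path)[0] + ".txt"
+         if not os.path.exists(caption_path):
+             continue  # skip images without a caption file
+         with open(caption_path, "r", encoding="utf-8") as f:
+             caption = f.read().strip()
+         # one JSON object per line, matching the format shown above
+         out.write(json.dumps({"image_path": image_path, "caption": caption}) + "\n")
+ ```
+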
54
+ ### Sample for Video Dataset with Caption Text Files
55
+
56
+ ```toml
57
+ # resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
58
+
59
+ # general configurations
60
+ [general]
61
+ resolution = [960, 544]
62
+ caption_extension = ".txt"
63
+ batch_size = 1
64
+ enable_bucket = true
65
+ bucket_no_upscale = false
66
+
67
+ [[datasets]]
68
+ video_directory = "/path/to/video_dir"
69
+ target_frames = [1, 25, 45]
70
+ frame_extraction = "head"
71
+
72
+ # other datasets can be added here. each dataset can have different configurations
73
+ ```
74
+
75
+ ### Sample for Video Dataset with Metadata JSONL File
76
+
77
+ ```toml
78
+ # resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
79
+ # caption_extension is not required for metadata jsonl file
80
+ # cache_directory is required for each dataset with metadata jsonl file
81
+
82
+ # general configurations
83
+ [general]
84
+ resolution = [960, 544]
85
+ batch_size = 1
86
+ enable_bucket = true
87
+ bucket_no_upscale = false
88
+
89
+ [[datasets]]
90
+ video_jsonl_file = "/path/to/metadata.jsonl"
91
+ target_frames = [1, 25, 45]
92
+ frame_extraction = "head"
93
+ cache_directory = "/path/to/cache_directory"
94
+
95
+ # same metadata jsonl file can be used for multiple datasets
96
+ [[datasets]]
97
+ video_jsonl_file = "/path/to/metadata.jsonl"
98
+ target_frames = [1]
99
+ frame_stride = 10
100
+ cache_directory = "/path/to/cache_directory"
101
+
102
+ # other datasets can be added here. each dataset can have different configurations
103
+ ```
104
+
105
+ JSONL file format for metadata:
106
+
107
+ ```json
108
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
109
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
110
+ ```
111
+
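+ A quick way to sanity-check such a file before caching is to parse each line and confirm the referenced video exists (a minimal sketch; the path is a placeholder):
+
+ ```python
+ import json
+ import os
+
+ jsonl_path = "/path/to/metadata.jsonl"  # placeholder
+
+ with open(jsonl_path, "r", encoding="utf-8") as f:
+     for line_no, line in enumerate(f, start=1):
+         entry = json.loads(line)
+         # every line needs a video path and its caption
+         assert "video_path" in entry and "caption" in entry, f"line {line_no}: missing key"
+         if not os.path.exists(entry["video_path"]):
+             print(f"line {line_no}: missing file {entry['video_path']}")
+ ```
+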
112
+ ### frame_extraction Options
113
+
114
+ - `head`: Extract the first N frames from the video.
115
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
116
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
117
+ - `uniform`: Extract `frame_sample` samples uniformly from the video.
118
+
119
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
120
+
121
+ ```
122
+ Original Video, 40 frames: x = extracted frame, o = not extracted
123
+ oooooooooooooooooooooooooooooooooooooooo
124
+
125
+ head, target_frames = [1, 13, 25] -> extract head frames:
126
+ xooooooooooooooooooooooooooooooooooooooo
127
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
128
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
129
+
130
+ chunk, target_frames = [13, 25] -> extract frames by splitting the video into chunks of 13 or 25 frames:
131
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
132
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
133
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
134
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
135
+
136
+ NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk". It will cause all frames to be extracted.
137
+
138
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
139
+ xooooooooooooooooooooooooooooooooooooooo
140
+ ooooooooooxooooooooooooooooooooooooooooo
141
+ ooooooooooooooooooooxooooooooooooooooooo
142
+ ooooooooooooooooooooooooooooooxooooooooo
143
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
144
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
145
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
146
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
147
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
148
+
149
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
150
+ xooooooooooooooooooooooooooooooooooooooo
151
+ oooooooooooooxoooooooooooooooooooooooooo
152
+ oooooooooooooooooooooooooxoooooooooooooo
153
+ ooooooooooooooooooooooooooooooooooooooox
154
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
155
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
156
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
157
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
158
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
159
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
160
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
161
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
162
+ ```
163
+
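+ The diagrams above can also be reproduced programmatically. The sketch below mirrors the window computation used by the dataset code; the helper name and the final print are ours, for illustration only:
+
+ ```python
+ import numpy as np
+
+ def extraction_windows(frame_count, target_frames, mode, frame_stride=1, frame_sample=4):
+     """Return (start, length) windows for each frame_extraction mode."""
+     windows = []
+     for target in target_frames:
+         if mode == "head":
+             if frame_count >= target:
+                 windows.append((0, target))
+         elif mode == "chunk":
+             for i in range(0, frame_count, target):
+                 if i + target <= frame_count:
+                     windows.append((i, target))
+         elif mode == "slide":
+             if frame_count >= target:
+                 for i in range(0, frame_count - target + 1, frame_stride):
+                     windows.append((i, target))
+         elif mode == "uniform":
+             if frame_count >= target:
+                 for i in np.linspace(0, frame_count - target, frame_sample, dtype=int):
+                     windows.append((int(i), target))
+         else:
+             raise ValueError(f"unknown frame_extraction mode: {mode}")
+     return windows
+
+ # reproduces the "slide" diagram above: 40 frames, stride 10
+ print(extraction_windows(40, [1, 13, 25], "slide", frame_stride=10))
+ ```
+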
164
+ ## Specifications
165
+
166
+ ```toml
167
+ # general configurations
168
+ [general]
169
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
170
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
171
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
172
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
173
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
174
+
175
+ ### Image Dataset
176
+
177
+ # sample image dataset with caption text files
178
+ [[datasets]]
179
+ image_directory = "/path/to/image_dir"
180
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
181
+ resolution = [960, 544] # required if general resolution is not set
182
+ batch_size = 4 # optional, overwrite the default batch size
183
+ enable_bucket = false # optional, overwrite the default bucketing setting
184
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
185
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
186
+
187
+ # sample image dataset with metadata **jsonl** file
188
+ [[datasets]]
189
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
190
+ resolution = [960, 544] # required if general resolution is not set
191
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
192
+ # caption_extension is not required for metadata jsonl file
193
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
194
+
195
+ ### Video Dataset
196
+
197
+ # sample video dataset with caption text files
198
+ [[datasets]]
199
+ video_directory = "/path/to/video_dir"
200
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
201
+ resolution = [960, 544] # required if general resolution is not set
202
+
203
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
204
+
205
+ # NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk". It will cause all frames to be extracted.
206
+
207
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
208
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
209
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
210
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
211
+
212
+ # sample video dataset with metadata jsonl file
213
+ [[datasets]]
214
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
215
+
216
+ target_frames = [1, 79]
217
+
218
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
219
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
220
+ ```
221
+
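+ To see how a configuration like the one above resolves, the TOML can be loaded directly with the `toml` package (the same library used by the repository's config loader). The snippet below is a rough sketch with a placeholder path, not the actual sanitizer:
+
+ ```python
+ import toml
+
+ config = toml.load("/path/to/dataset_config.toml")  # placeholder
+ general = config.get("general", {})
+
+ for i, ds in enumerate(config.get("datasets", [])):
+     # per-dataset values fall back to the [general] section
+     resolution = ds.get("resolution", general.get("resolution"))
+     batch_size = ds.get("batch_size", general.get("batch_size", 1))
+     kind = "image" if ("image_directory" in ds or "image_jsonl_file" in ds) else "video"
+     print(f"[dataset {i}] type={kind} resolution={resolution} batch_size={batch_size}")
+ ```
+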
222
+ <!--
223
+ # sample image dataset with lance
224
+ [[datasets]]
225
+ image_lance_dataset = "/path/to/lance_dataset"
226
+ resolution = [960, 544] # required if general resolution is not set
227
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
228
+ -->
229
+
230
+ Metadata in .json format will be supported in the near future.
231
+
232
+
233
+
234
+ <!--
235
+
236
+ ```toml
237
+ # general configurations
238
+ [general]
239
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
240
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
241
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
242
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
243
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
244
+
245
+ # sample image dataset with caption text files
246
+ [[datasets]]
247
+ image_directory = "/path/to/image_dir"
248
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
249
+ resolution = [960, 544] # required if general resolution is not set
250
+ batch_size = 4 # optional, overwrite the default batch size
251
+ enable_bucket = false # optional, overwrite the default bucketing setting
252
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
253
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
254
+
255
+ # sample image dataset with metadata **jsonl** file
256
+ [[datasets]]
257
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
258
+ resolution = [960, 544] # required if general resolution is not set
259
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
260
+ # caption_extension is not required for metadata jsonl file
261
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
262
+
263
+ # sample video dataset with caption text files
264
+ [[datasets]]
265
+ video_directory = "/path/to/video_dir"
266
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
267
+ resolution = [960, 544] # required if general resolution is not set
268
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
269
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
270
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
271
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
272
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
273
+
274
+ # sample video dataset with metadata jsonl file
275
+ [[datasets]]
276
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
277
+ target_frames = [1, 79]
278
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
279
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
280
+ ```
281
+
282
+ # sample image dataset with lance
283
+ [[datasets]]
284
+ image_lance_dataset = "/path/to/lance_dataset"
285
+ resolution = [960, 544] # required if general resolution is not set
286
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
287
+
288
+ The metadata with .json file will be supported in the near future.
289
+
290
+
291
+
292
+
293
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1255 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
28
+
29
+ try:
30
+ import pillow_avif
31
+
32
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
33
+ except ImportError:
34
+ pass
35
+
36
+ # JPEG-XL on Linux
37
+ try:
38
+ from jxlpy import JXLImagePlugin
39
+
40
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
41
+ except ImportError:
42
+ pass
43
+
44
+ # JPEG-XL on Windows
45
+ try:
46
+ import pillow_jxl
47
+
48
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
49
+ except ImportError:
50
+ pass
51
+
52
+ VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"] # some of them are not tested
53
+
54
+ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
55
+
56
+
57
+ def glob_images(directory, base="*"):
58
+ img_paths = []
59
+ for ext in IMAGE_EXTENSIONS:
60
+ if base == "*":
61
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
62
+ else:
63
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
64
+ img_paths = list(set(img_paths)) # remove duplicates
65
+ img_paths.sort()
66
+ return img_paths
67
+
68
+
69
+ def glob_videos(directory, base="*"):
70
+ video_paths = []
71
+ for ext in VIDEO_EXTENSIONS:
72
+ if base == "*":
73
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
74
+ else:
75
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
76
+ video_paths = list(set(video_paths)) # remove duplicates
77
+ video_paths.sort()
78
+ return video_paths
79
+
80
+
81
+ def divisible_by(num: int, divisor: int) -> int:
82
+ return num - num % divisor
83
+
84
+
85
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
86
+ """
87
+ Resize the image to the bucket resolution.
88
+ """
89
+ is_pil_image = isinstance(image, Image.Image)
90
+ if is_pil_image:
91
+ image_width, image_height = image.size
92
+ else:
93
+ image_height, image_width = image.shape[:2]
94
+
95
+ if bucket_reso == (image_width, image_height):
96
+ return np.array(image) if is_pil_image else image
97
+
98
+ bucket_width, bucket_height = bucket_reso
99
+ if bucket_width == image_width or bucket_height == image_height:
100
+ image = np.array(image) if is_pil_image else image
101
+ else:
102
+ # resize the image to the bucket resolution to match the short side
103
+ scale_width = bucket_width / image_width
104
+ scale_height = bucket_height / image_height
105
+ scale = max(scale_width, scale_height)
106
+ image_width = int(image_width * scale + 0.5)
107
+ image_height = int(image_height * scale + 0.5)
108
+
109
+ if scale > 1:
110
+ image = Image.fromarray(image) if not is_pil_image else image
111
+ image = image.resize((image_width, image_height), Image.LANCZOS)
112
+ image = np.array(image)
113
+ else:
114
+ image = np.array(image) if is_pil_image else image
115
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
116
+
117
+ # crop the image to the bucket resolution
118
+ crop_left = (image_width - bucket_width) // 2
119
+ crop_top = (image_height - bucket_height) // 2
120
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
121
+ return image
122
+
123
+
124
+ class ItemInfo:
125
+ def __init__(
126
+ self,
127
+ item_key: str,
128
+ caption: str,
129
+ original_size: tuple[int, int],
130
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
131
+ frame_count: Optional[int] = None,
132
+ content: Optional[np.ndarray] = None,
133
+ latent_cache_path: Optional[str] = None,
134
+ ) -> None:
135
+ self.item_key = item_key
136
+ self.caption = caption
137
+ self.original_size = original_size
138
+ self.bucket_size = bucket_size
139
+ self.frame_count = frame_count
140
+ self.content = content
141
+ self.latent_cache_path = latent_cache_path
142
+ self.text_encoder_output_cache_path: Optional[str] = None
143
+
144
+ def __str__(self) -> str:
145
+ return (
146
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
147
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
148
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
149
+ )
150
+
151
+
152
+ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
153
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
154
+ metadata = {
155
+ "architecture": "hunyuan_video",
156
+ "width": f"{item_info.original_size[0]}",
157
+ "height": f"{item_info.original_size[1]}",
158
+ "format_version": "1.0.0",
159
+ }
160
+ if item_info.frame_count is not None:
161
+ metadata["frame_count"] = f"{item_info.frame_count}"
162
+
163
+ _, F, H, W = latent.shape
164
+ dtype_str = dtype_to_str(latent.dtype)
165
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
166
+
167
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
168
+ os.makedirs(latent_dir, exist_ok=True)
169
+
170
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
171
+
172
+
173
+ def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
174
+ assert (
175
+ embed.dim() == 1 or embed.dim() == 2
176
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
177
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
178
+ metadata = {
179
+ "architecture": "hunyuan_video",
180
+ "caption1": item_info.caption,
181
+ "format_version": "1.0.0",
182
+ }
183
+
184
+ sd = {}
185
+ if os.path.exists(item_info.text_encoder_output_cache_path):
186
+ # load existing cache and update metadata
187
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
188
+ existing_metadata = f.metadata()
189
+ for key in f.keys():
190
+ sd[key] = f.get_tensor(key)
191
+
192
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
193
+ if existing_metadata["caption1"] != metadata["caption1"]:
194
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
195
+ # TODO verify format_version
196
+
197
+ existing_metadata.pop("caption1", None)
198
+ existing_metadata.pop("format_version", None)
199
+ metadata.update(existing_metadata) # copy existing metadata
200
+ else:
201
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
202
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
203
+
204
+ dtype_str = dtype_to_str(embed.dtype)
205
+ text_encoder_type = "llm" if is_llm else "clipL"
206
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
207
+ if mask is not None:
208
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
209
+
210
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
211
+
212
+
213
+ class BucketSelector:
214
+ RESOLUTION_STEPS_HUNYUAN = 16
215
+
216
+ def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
217
+ self.resolution = resolution
218
+ self.bucket_area = resolution[0] * resolution[1]
219
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
220
+
221
+ if not enable_bucket:
222
+ # only define one bucket
223
+ self.bucket_resolutions = [resolution]
224
+ self.no_upscale = False
225
+ else:
226
+ # prepare bucket resolution
227
+ self.no_upscale = no_upscale
228
+ sqrt_size = int(math.sqrt(self.bucket_area))
229
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
230
+ self.bucket_resolutions = []
231
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
232
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
233
+ self.bucket_resolutions.append((w, h))
234
+ self.bucket_resolutions.append((h, w))
235
+
236
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
237
+ self.bucket_resolutions.sort()
238
+
239
+ # calculate aspect ratio to find the nearest resolution
240
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
241
+
242
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
243
+ """
244
+ return the bucket resolution for the given image size, (width, height)
245
+ """
246
+ area = image_size[0] * image_size[1]
247
+ if self.no_upscale and area <= self.bucket_area:
248
+ w, h = image_size
249
+ w = divisible_by(w, self.reso_steps)
250
+ h = divisible_by(h, self.reso_steps)
251
+ return w, h
252
+
253
+ aspect_ratio = image_size[0] / image_size[1]
254
+ ar_errors = self.aspect_ratios - aspect_ratio
255
+ bucket_id = np.abs(ar_errors).argmin()
256
+ return self.bucket_resolutions[bucket_id]
257
+
258
+
259
+ def load_video(
260
+ video_path: str,
261
+ start_frame: Optional[int] = None,
262
+ end_frame: Optional[int] = None,
263
+ bucket_selector: Optional[BucketSelector] = None,
264
+ ) -> list[np.ndarray]:
265
+ container = av.open(video_path)
266
+ video = []
267
+ bucket_reso = None
268
+ for i, frame in enumerate(container.decode(video=0)):
269
+ if start_frame is not None and i < start_frame:
270
+ continue
271
+ if end_frame is not None and i >= end_frame:
272
+ break
273
+ frame = frame.to_image()
274
+
275
+ if bucket_selector is not None and bucket_reso is None:
276
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
277
+
278
+ if bucket_reso is not None:
279
+ frame = resize_image_to_bucket(frame, bucket_reso)
280
+ else:
281
+ frame = np.array(frame)
282
+
283
+ video.append(frame)
284
+ container.close()
285
+ return video
286
+
287
+
288
+ class BucketBatchManager:
289
+
290
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
291
+ self.batch_size = batch_size
292
+ self.buckets = bucketed_item_info
293
+ self.bucket_resos = list(self.buckets.keys())
294
+ self.bucket_resos.sort()
295
+
296
+ self.bucket_batch_indices = []
297
+ for bucket_reso in self.bucket_resos:
298
+ bucket = self.buckets[bucket_reso]
299
+ num_batches = math.ceil(len(bucket) / self.batch_size)
300
+ for i in range(num_batches):
301
+ self.bucket_batch_indices.append((bucket_reso, i))
302
+
303
+ self.shuffle()
304
+
305
+ def show_bucket_info(self):
306
+ for bucket_reso in self.bucket_resos:
307
+ bucket = self.buckets[bucket_reso]
308
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
309
+
310
+ logger.info(f"total batches: {len(self)}")
311
+
312
+ def shuffle(self):
313
+ for bucket in self.buckets.values():
314
+ random.shuffle(bucket)
315
+ random.shuffle(self.bucket_batch_indices)
316
+
317
+ def __len__(self):
318
+ return len(self.bucket_batch_indices)
319
+
320
+ def __getitem__(self, idx):
321
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
322
+ bucket = self.buckets[bucket_reso]
323
+ start = batch_idx * self.batch_size
324
+ end = min(start + self.batch_size, len(bucket))
325
+
326
+ latents = []
327
+ llm_embeds = []
328
+ llm_masks = []
329
+ clip_l_embeds = []
330
+ for item_info in bucket[start:end]:
331
+ sd = load_file(item_info.latent_cache_path)
332
+ latent = None
333
+ for key in sd.keys():
334
+ if key.startswith("latents_"):
335
+ latent = sd[key]
336
+ break
337
+ latents.append(latent)
338
+
339
+ sd = load_file(item_info.text_encoder_output_cache_path)
340
+ llm_embed = llm_mask = clip_l_embed = None
341
+ for key in sd.keys():
342
+ if key.startswith("llm_mask"):
343
+ llm_mask = sd[key]
344
+ elif key.startswith("llm_"):
345
+ llm_embed = sd[key]
346
+ elif key.startswith("clipL_mask"):
347
+ pass
348
+ elif key.startswith("clipL_"):
349
+ clip_l_embed = sd[key]
350
+ llm_embeds.append(llm_embed)
351
+ llm_masks.append(llm_mask)
352
+ clip_l_embeds.append(clip_l_embed)
353
+
354
+ latents = torch.stack(latents)
355
+ llm_embeds = torch.stack(llm_embeds)
356
+ llm_masks = torch.stack(llm_masks)
357
+ clip_l_embeds = torch.stack(clip_l_embeds)
358
+
359
+ return latents, llm_embeds, llm_masks, clip_l_embeds
360
+
361
+
362
+ class ContentDatasource:
363
+ def __init__(self):
364
+ self.caption_only = False
365
+
366
+ def set_caption_only(self, caption_only: bool):
367
+ self.caption_only = caption_only
368
+
369
+ def is_indexable(self):
370
+ return False
371
+
372
+ def get_caption(self, idx: int) -> tuple[str, str]:
373
+ """
374
+ Returns caption. May not be called if is_indexable() returns False.
375
+ """
376
+ raise NotImplementedError
377
+
378
+ def __len__(self):
379
+ raise NotImplementedError
380
+
381
+ def __iter__(self):
382
+ raise NotImplementedError
383
+
384
+ def __next__(self):
385
+ raise NotImplementedError
386
+
387
+
388
+ class ImageDatasource(ContentDatasource):
389
+ def __init__(self):
390
+ super().__init__()
391
+
392
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
393
+ """
394
+ Returns image data as a tuple of image path, image, and caption for the given index.
395
+ Key must be unique and valid as a file name.
396
+ May not be called if is_indexable() returns False.
397
+ """
398
+ raise NotImplementedError
399
+
400
+
401
+ class ImageDirectoryDatasource(ImageDatasource):
402
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
403
+ super().__init__()
404
+ self.image_directory = image_directory
405
+ self.caption_extension = caption_extension
406
+ self.current_idx = 0
407
+
408
+ # glob images
409
+ logger.info(f"glob images in {self.image_directory}")
410
+ self.image_paths = glob_images(self.image_directory)
411
+ logger.info(f"found {len(self.image_paths)} images")
412
+
413
+ def is_indexable(self):
414
+ return True
415
+
416
+ def __len__(self):
417
+ return len(self.image_paths)
418
+
419
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
420
+ image_path = self.image_paths[idx]
421
+ image = Image.open(image_path).convert("RGB")
422
+
423
+ _, caption = self.get_caption(idx)
424
+
425
+ return image_path, image, caption
426
+
427
+ def get_caption(self, idx: int) -> tuple[str, str]:
428
+ image_path = self.image_paths[idx]
429
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
430
+ with open(caption_path, "r", encoding="utf-8") as f:
431
+ caption = f.read().strip()
432
+ return image_path, caption
433
+
434
+ def __iter__(self):
435
+ self.current_idx = 0
436
+ return self
437
+
438
+ def __next__(self) -> callable:
439
+ """
440
+ Returns a fetcher function that returns image data.
441
+ """
442
+ if self.current_idx >= len(self.image_paths):
443
+ raise StopIteration
444
+
445
+ if self.caption_only:
446
+
447
+ def create_caption_fetcher(index):
448
+ return lambda: self.get_caption(index)
449
+
450
+ fetcher = create_caption_fetcher(self.current_idx)
451
+ else:
452
+
453
+ def create_image_fetcher(index):
454
+ return lambda: self.get_image_data(index)
455
+
456
+ fetcher = create_image_fetcher(self.current_idx)
457
+
458
+ self.current_idx += 1
459
+ return fetcher
460
+
461
+
462
+ class ImageJsonlDatasource(ImageDatasource):
463
+ def __init__(self, image_jsonl_file: str):
464
+ super().__init__()
465
+ self.image_jsonl_file = image_jsonl_file
466
+ self.current_idx = 0
467
+
468
+ # load jsonl
469
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
470
+ self.data = []
471
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
472
+ for line in f:
473
+ data = json.loads(line)
474
+ self.data.append(data)
475
+ logger.info(f"loaded {len(self.data)} images")
476
+
477
+ def is_indexable(self):
478
+ return True
479
+
480
+ def __len__(self):
481
+ return len(self.data)
482
+
483
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
484
+ data = self.data[idx]
485
+ image_path = data["image_path"]
486
+ image = Image.open(image_path).convert("RGB")
487
+
488
+ caption = data["caption"]
489
+
490
+ return image_path, image, caption
491
+
492
+ def get_caption(self, idx: int) -> tuple[str, str]:
493
+ data = self.data[idx]
494
+ image_path = data["image_path"]
495
+ caption = data["caption"]
496
+ return image_path, caption
497
+
498
+ def __iter__(self):
499
+ self.current_idx = 0
500
+ return self
501
+
502
+ def __next__(self) -> callable:
503
+ if self.current_idx >= len(self.data):
504
+ raise StopIteration
505
+
506
+ if self.caption_only:
507
+
508
+ def create_caption_fetcher(index):
509
+ return lambda: self.get_caption(index)
510
+
511
+ fetcher = create_caption_fetcher(self.current_idx)
512
+
513
+ else:
514
+
515
+ def create_fetcher(index):
516
+ return lambda: self.get_image_data(index)
517
+
518
+ fetcher = create_fetcher(self.current_idx)
519
+
520
+ self.current_idx += 1
521
+ return fetcher
522
+
523
+
524
+ class VideoDatasource(ContentDatasource):
525
+ def __init__(self):
526
+ super().__init__()
527
+
528
+ # None means all frames
529
+ self.start_frame = None
530
+ self.end_frame = None
531
+
532
+ self.bucket_selector = None
533
+
534
+ def __len__(self):
535
+ raise NotImplementedError
536
+
537
+ def get_video_data_from_path(
538
+ self,
539
+ video_path: str,
540
+ start_frame: Optional[int] = None,
541
+ end_frame: Optional[int] = None,
542
+ bucket_selector: Optional[BucketSelector] = None,
543
+ ) -> tuple[str, list[Image.Image], str]:
544
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
545
+
546
+ start_frame = start_frame if start_frame is not None else self.start_frame
547
+ end_frame = end_frame if end_frame is not None else self.end_frame
548
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
549
+
550
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
551
+ return video
552
+
553
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
554
+ self.start_frame = start_frame
555
+ self.end_frame = end_frame
556
+
557
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
558
+ self.bucket_selector = bucket_selector
559
+
560
+ def __iter__(self):
561
+ raise NotImplementedError
562
+
563
+ def __next__(self):
564
+ raise NotImplementedError
565
+
566
+
567
+ class VideoDirectoryDatasource(VideoDatasource):
568
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
569
+ super().__init__()
570
+ self.video_directory = video_directory
571
+ self.caption_extension = caption_extension
572
+ self.current_idx = 0
573
+
574
+ # glob videos
575
+ logger.info(f"glob videos in {self.video_directory}")
576
+ self.video_paths = glob_videos(self.video_directory)
577
+ logger.info(f"found {len(self.video_paths)} videos")
578
+
579
+ def is_indexable(self):
580
+ return True
581
+
582
+ def __len__(self):
583
+ return len(self.video_paths)
584
+
585
+ def get_video_data(
586
+ self,
587
+ idx: int,
588
+ start_frame: Optional[int] = None,
589
+ end_frame: Optional[int] = None,
590
+ bucket_selector: Optional[BucketSelector] = None,
591
+ ) -> tuple[str, list[Image.Image], str]:
592
+ video_path = self.video_paths[idx]
593
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
594
+
595
+ _, caption = self.get_caption(idx)
596
+
597
+ return video_path, video, caption
598
+
599
+ def get_caption(self, idx: int) -> tuple[str, str]:
600
+ video_path = self.video_paths[idx]
601
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
602
+ with open(caption_path, "r", encoding="utf-8") as f:
603
+ caption = f.read().strip()
604
+ return video_path, caption
605
+
606
+ def __iter__(self):
607
+ self.current_idx = 0
608
+ return self
609
+
610
+ def __next__(self):
611
+ if self.current_idx >= len(self.video_paths):
612
+ raise StopIteration
613
+
614
+ if self.caption_only:
615
+
616
+ def create_caption_fetcher(index):
617
+ return lambda: self.get_caption(index)
618
+
619
+ fetcher = create_caption_fetcher(self.current_idx)
620
+
621
+ else:
622
+
623
+ def create_fetcher(index):
624
+ return lambda: self.get_video_data(index)
625
+
626
+ fetcher = create_fetcher(self.current_idx)
627
+
628
+ self.current_idx += 1
629
+ return fetcher
630
+
631
+
632
+ class VideoJsonlDatasource(VideoDatasource):
633
+ def __init__(self, video_jsonl_file: str):
634
+ super().__init__()
635
+ self.video_jsonl_file = video_jsonl_file
636
+ self.current_idx = 0
637
+
638
+ # load jsonl
639
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
640
+ self.data = []
641
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
642
+ for line in f:
643
+ data = json.loads(line)
644
+ self.data.append(data)
645
+ logger.info(f"loaded {len(self.data)} videos")
646
+
647
+ def is_indexable(self):
648
+ return True
649
+
650
+ def __len__(self):
651
+ return len(self.data)
652
+
653
+ def get_video_data(
654
+ self,
655
+ idx: int,
656
+ start_frame: Optional[int] = None,
657
+ end_frame: Optional[int] = None,
658
+ bucket_selector: Optional[BucketSelector] = None,
659
+ ) -> tuple[str, list[Image.Image], str]:
660
+ data = self.data[idx]
661
+ video_path = data["video_path"]
662
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
663
+
664
+ caption = data["caption"]
665
+
666
+ return video_path, video, caption
667
+
668
+ def get_caption(self, idx: int) -> tuple[str, str]:
669
+ data = self.data[idx]
670
+ video_path = data["video_path"]
671
+ caption = data["caption"]
672
+ return video_path, caption
673
+
674
+ def __iter__(self):
675
+ self.current_idx = 0
676
+ return self
677
+
678
+ def __next__(self):
679
+ if self.current_idx >= len(self.data):
680
+ raise StopIteration
681
+
682
+ if self.caption_only:
683
+
684
+ def create_caption_fetcher(index):
685
+ return lambda: self.get_caption(index)
686
+
687
+ fetcher = create_caption_fetcher(self.current_idx)
688
+
689
+ else:
690
+
691
+ def create_fetcher(index):
692
+ return lambda: self.get_video_data(index)
693
+
694
+ fetcher = create_fetcher(self.current_idx)
695
+
696
+ self.current_idx += 1
697
+ return fetcher
698
+
699
+
700
+ class BaseDataset(torch.utils.data.Dataset):
701
+ def __init__(
702
+ self,
703
+ resolution: Tuple[int, int] = (960, 544),
704
+ caption_extension: Optional[str] = None,
705
+ batch_size: int = 1,
706
+ enable_bucket: bool = False,
707
+ bucket_no_upscale: bool = False,
708
+ cache_directory: Optional[str] = None,
709
+ debug_dataset: bool = False,
710
+ ):
711
+ self.resolution = resolution
712
+ self.caption_extension = caption_extension
713
+ self.batch_size = batch_size
714
+ self.enable_bucket = enable_bucket
715
+ self.bucket_no_upscale = bucket_no_upscale
716
+ self.cache_directory = cache_directory
717
+ self.debug_dataset = debug_dataset
718
+ self.seed = None
719
+ self.current_epoch = 0
720
+
721
+ if not self.enable_bucket:
722
+ self.bucket_no_upscale = False
723
+
724
+ def get_metadata(self) -> dict:
725
+ metadata = {
726
+ "resolution": self.resolution,
727
+ "caption_extension": self.caption_extension,
728
+ "batch_size_per_device": self.batch_size,
729
+ "enable_bucket": bool(self.enable_bucket),
730
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
731
+ }
732
+ return metadata
733
+
734
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
735
+ w, h = item_info.original_size
736
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
737
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
738
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")
739
+
740
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
741
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
742
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
743
+ return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")
744
+
745
+ def retrieve_latent_cache_batches(self, num_workers: int):
746
+ raise NotImplementedError
747
+
748
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
749
+ raise NotImplementedError
750
+
751
+ def prepare_for_training(self):
752
+ pass
753
+
754
+ def set_seed(self, seed: int):
755
+ self.seed = seed
756
+
757
+ def set_current_epoch(self, epoch):
758
+ if self.current_epoch != epoch: # shuffle buckets when epoch is incremented
759
+ if epoch > self.current_epoch:
760
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
761
+ num_epochs = epoch - self.current_epoch
762
+ for _ in range(num_epochs):
763
+ self.current_epoch += 1
764
+ self.shuffle_buckets()
765
+ # self.current_epoch seems to be reset to 0 in the next epoch; this may be caused by skipped_dataloader?
766
+ else:
767
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
768
+ self.current_epoch = epoch
769
+
770
+ def set_current_step(self, step):
771
+ self.current_step = step
772
+
773
+ def set_max_train_steps(self, max_train_steps):
774
+ self.max_train_steps = max_train_steps
775
+
776
+ def shuffle_buckets(self):
777
+ raise NotImplementedError
778
+
779
+ def __len__(self):
780
+ raise NotImplementedError
781
+
782
+ def __getitem__(self, idx):
783
+ raise NotImplementedError
784
+
785
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
786
+ datasource.set_caption_only(True)
787
+ executor = ThreadPoolExecutor(max_workers=num_workers)
788
+
789
+ data: list[ItemInfo] = []
790
+ futures = []
791
+
792
+ def aggregate_future(consume_all: bool = False):
793
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
794
+ completed_futures = [future for future in futures if future.done()]
795
+ if len(completed_futures) == 0:
796
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
797
+ time.sleep(0.1)
798
+ continue
799
+ else:
800
+ break # submit batch if possible
801
+
802
+ for future in completed_futures:
803
+ item_key, caption = future.result()
804
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
805
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
806
+ data.append(item_info)
807
+
808
+ futures.remove(future)
809
+
810
+ def submit_batch(flush: bool = False):
811
+ nonlocal data
812
+ if len(data) >= batch_size or (len(data) > 0 and flush):
813
+ batch = data[0:batch_size]
814
+ if len(data) > batch_size:
815
+ data = data[batch_size:]
816
+ else:
817
+ data = []
818
+ return batch
819
+ return None
820
+
821
+ for fetch_op in datasource:
822
+ future = executor.submit(fetch_op)
823
+ futures.append(future)
824
+ aggregate_future()
825
+ while True:
826
+ batch = submit_batch()
827
+ if batch is None:
828
+ break
829
+ yield batch
830
+
831
+ aggregate_future(consume_all=True)
832
+ while True:
833
+ batch = submit_batch(flush=True)
834
+ if batch is None:
835
+ break
836
+ yield batch
837
+
838
+ executor.shutdown()
839
+
840
+
841
+ class ImageDataset(BaseDataset):
842
+ def __init__(
843
+ self,
844
+ resolution: Tuple[int, int],
845
+ caption_extension: Optional[str],
846
+ batch_size: int,
847
+ enable_bucket: bool,
848
+ bucket_no_upscale: bool,
849
+ image_directory: Optional[str] = None,
850
+ image_jsonl_file: Optional[str] = None,
851
+ cache_directory: Optional[str] = None,
852
+ debug_dataset: bool = False,
853
+ ):
854
+ super(ImageDataset, self).__init__(
855
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
856
+ )
857
+ self.image_directory = image_directory
858
+ self.image_jsonl_file = image_jsonl_file
859
+ if image_directory is not None:
860
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
861
+ elif image_jsonl_file is not None:
862
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
863
+ else:
864
+ raise ValueError("image_directory or image_jsonl_file must be specified")
865
+
866
+ if self.cache_directory is None:
867
+ self.cache_directory = self.image_directory
868
+
869
+ self.batch_manager = None
870
+ self.num_train_items = 0
871
+
872
+ def get_metadata(self):
873
+ metadata = super().get_metadata()
874
+ if self.image_directory is not None:
875
+ metadata["image_directory"] = os.path.basename(self.image_directory)
876
+ if self.image_jsonl_file is not None:
877
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
878
+ return metadata
879
+
880
+ def get_total_image_count(self):
881
+ return len(self.datasource) if self.datasource.is_indexable() else None
882
+
883
+ def retrieve_latent_cache_batches(self, num_workers: int):
884
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
885
+ executor = ThreadPoolExecutor(max_workers=num_workers)
886
+
887
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
888
+ futures = []
889
+
890
+ def aggregate_future(consume_all: bool = False):
891
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
892
+ completed_futures = [future for future in futures if future.done()]
893
+ if len(completed_futures) == 0:
894
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
895
+ time.sleep(0.1)
896
+ continue
897
+ else:
898
+ break # submit batch if possible
899
+
900
+ for future in completed_futures:
901
+ original_size, item_key, image, caption = future.result()
902
+ bucket_height, bucket_width = image.shape[:2]
903
+ bucket_reso = (bucket_width, bucket_height)
904
+
905
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
906
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
907
+
908
+ if bucket_reso not in batches:
909
+ batches[bucket_reso] = []
910
+ batches[bucket_reso].append(item_info)
911
+
912
+ futures.remove(future)
913
+
914
+ def submit_batch(flush: bool = False):
915
+ for key in batches:
916
+ if len(batches[key]) >= self.batch_size or flush:
917
+ batch = batches[key][0 : self.batch_size]
918
+ if len(batches[key]) > self.batch_size:
919
+ batches[key] = batches[key][self.batch_size :]
920
+ else:
921
+ del batches[key]
922
+ return key, batch
923
+ return None, None
924
+
925
+ for fetch_op in self.datasource:
926
+
927
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
928
+ image_key, image, caption = op()
929
+ image: Image.Image
930
+ image_size = image.size
931
+
932
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
933
+ image = resize_image_to_bucket(image, bucket_reso)
934
+ return image_size, image_key, image, caption
935
+
936
+ future = executor.submit(fetch_and_resize, fetch_op)
937
+ futures.append(future)
938
+ aggregate_future()
939
+ while True:
940
+ key, batch = submit_batch()
941
+ if key is None:
942
+ break
943
+ yield key, batch
944
+
945
+ aggregate_future(consume_all=True)
946
+ while True:
947
+ key, batch = submit_batch(flush=True)
948
+ if key is None:
949
+ break
950
+ yield key, batch
951
+
952
+ executor.shutdown()
953
+
954
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
955
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
956
+
957
+ def prepare_for_training(self):
958
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
959
+
960
+ # glob cache files
961
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
962
+
963
+ # assign cache files to item info
964
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
965
+ for cache_file in latent_cache_files:
966
+ tokens = os.path.basename(cache_file).split("_")
967
+
968
+ image_size = tokens[-2] # 0000x0000
969
+ image_width, image_height = map(int, image_size.split("x"))
970
+ image_size = (image_width, image_height)
971
+
972
+ item_key = "_".join(tokens[:-2])
973
+ text_encoder_output_cache_file = os.path.join(
974
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
975
+ )
976
+ if not os.path.exists(text_encoder_output_cache_file):
977
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
978
+ continue
979
+
980
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
981
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
982
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
983
+
984
+ bucket = bucketed_item_info.get(bucket_reso, [])
985
+ bucket.append(item_info)
986
+ bucketed_item_info[bucket_reso] = bucket
987
+
988
+ # prepare batch manager
989
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
990
+ self.batch_manager.show_bucket_info()
991
+
992
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
993
+
994
+ def shuffle_buckets(self):
995
+ # set random seed for this epoch
996
+ random.seed(self.seed + self.current_epoch)
997
+ self.batch_manager.shuffle()
998
+
999
+ def __len__(self):
1000
+ if self.batch_manager is None:
1001
+ return 100 # dummy value
1002
+ return len(self.batch_manager)
1003
+
1004
+ def __getitem__(self, idx):
1005
+ return self.batch_manager[idx]
1006
+
1007
+
1008
+ class VideoDataset(BaseDataset):
1009
+ def __init__(
1010
+ self,
1011
+ resolution: Tuple[int, int],
1012
+ caption_extension: Optional[str],
1013
+ batch_size: int,
1014
+ enable_bucket: bool,
1015
+ bucket_no_upscale: bool,
1016
+ frame_extraction: Optional[str] = "head",
1017
+ frame_stride: Optional[int] = 1,
1018
+ frame_sample: Optional[int] = 1,
1019
+ target_frames: Optional[list[int]] = None,
1020
+ video_directory: Optional[str] = None,
1021
+ video_jsonl_file: Optional[str] = None,
1022
+ cache_directory: Optional[str] = None,
1023
+ debug_dataset: bool = False,
1024
+ ):
1025
+ super(VideoDataset, self).__init__(
1026
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
1027
+ )
1028
+ self.video_directory = video_directory
1029
+ self.video_jsonl_file = video_jsonl_file
1030
+ self.target_frames = target_frames
1031
+ self.frame_extraction = frame_extraction
1032
+ self.frame_stride = frame_stride
1033
+ self.frame_sample = frame_sample
1034
+
1035
+ if video_directory is not None:
1036
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1037
+ elif video_jsonl_file is not None:
1038
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
1039
+
1040
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
1041
+ self.frame_extraction = "head"
1042
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
1043
+ if self.frame_extraction == "head":
1044
+ # head extraction. we can limit the number of frames to be extracted
1045
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
1046
+
1047
+ if self.cache_directory is None:
1048
+ self.cache_directory = self.video_directory
1049
+
1050
+ self.batch_manager = None
1051
+ self.num_train_items = 0
1052
+
1053
+ def get_metadata(self):
1054
+ metadata = super().get_metadata()
1055
+ if self.video_directory is not None:
1056
+ metadata["video_directory"] = os.path.basename(self.video_directory)
1057
+ if self.video_jsonl_file is not None:
1058
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1059
+ metadata["frame_extraction"] = self.frame_extraction
1060
+ metadata["frame_stride"] = self.frame_stride
1061
+ metadata["frame_sample"] = self.frame_sample
1062
+ metadata["target_frames"] = self.target_frames
1063
+ return metadata
1064
+
1065
+ def retrieve_latent_cache_batches(self, num_workers: int):
1066
+ bucket_selector = BucketSelector(self.resolution)
1067
+ self.datasource.set_bucket_selector(bucket_selector)
1068
+
1069
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1070
+
1071
+ # key: (width, height, frame_count), value: [ItemInfo]
1072
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1073
+ futures = []
1074
+
1075
+ def aggregate_future(consume_all: bool = False):
1076
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1077
+ completed_futures = [future for future in futures if future.done()]
1078
+ if len(completed_futures) == 0:
1079
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1080
+ time.sleep(0.1)
1081
+ continue
1082
+ else:
1083
+ break # submit batch if possible
1084
+
1085
+ for future in completed_futures:
1086
+ original_frame_size, video_key, video, caption = future.result()
1087
+
1088
+ frame_count = len(video)
1089
+ video = np.stack(video, axis=0)
1090
+ height, width = video.shape[1:3]
1091
+ bucket_reso = (width, height) # already resized
1092
+
1093
+ crop_pos_and_frames = []
1094
+ if self.frame_extraction == "head":
1095
+ for target_frame in self.target_frames:
1096
+ if frame_count >= target_frame:
1097
+ crop_pos_and_frames.append((0, target_frame))
1098
+ elif self.frame_extraction == "chunk":
1099
+ # split by target_frames
1100
+ for target_frame in self.target_frames:
1101
+ for i in range(0, frame_count, target_frame):
1102
+ if i + target_frame <= frame_count:
1103
+ crop_pos_and_frames.append((i, target_frame))
1104
+ elif self.frame_extraction == "slide":
1105
+ # slide window
1106
+ for target_frame in self.target_frames:
1107
+ if frame_count >= target_frame:
1108
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
1109
+ crop_pos_and_frames.append((i, target_frame))
1110
+ elif self.frame_extraction == "uniform":
1111
+ # select N frames uniformly
1112
+ for target_frame in self.target_frames:
1113
+ if frame_count >= target_frame:
1114
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1115
+ for i in frame_indices:
1116
+ crop_pos_and_frames.append((i, target_frame))
1117
+ else:
1118
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1119
+
1120
+ for crop_pos, target_frame in crop_pos_and_frames:
1121
+ cropped_video = video[crop_pos : crop_pos + target_frame]
1122
+ body, ext = os.path.splitext(video_key)
1123
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1124
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1125
+
1126
+ item_info = ItemInfo(
1127
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1128
+ )
1129
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1130
+
1131
+ batch = batches.get(batch_key, [])
1132
+ batch.append(item_info)
1133
+ batches[batch_key] = batch
1134
+
1135
+ futures.remove(future)
1136
+
1137
+ def submit_batch(flush: bool = False):
1138
+ for key in batches:
1139
+ if len(batches[key]) >= self.batch_size or flush:
1140
+ batch = batches[key][0 : self.batch_size]
1141
+ if len(batches[key]) > self.batch_size:
1142
+ batches[key] = batches[key][self.batch_size :]
1143
+ else:
1144
+ del batches[key]
1145
+ return key, batch
1146
+ return None, None
1147
+
1148
+ for operator in self.datasource:
1149
+
1150
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1151
+ video_key, video, caption = op()
1152
+ video: list[np.ndarray]
1153
+ frame_size = (video[0].shape[1], video[0].shape[0])
1154
+
1155
+ # resize if necessary
1156
+ bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
1157
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1158
+
1159
+ return frame_size, video_key, video, caption
1160
+
1161
+ future = executor.submit(fetch_and_resize, operator)
1162
+ futures.append(future)
1163
+ aggregate_future()
1164
+ while True:
1165
+ key, batch = submit_batch()
1166
+ if key is None:
1167
+ break
1168
+ yield key, batch
1169
+
1170
+ aggregate_future(consume_all=True)
1171
+ while True:
1172
+ key, batch = submit_batch(flush=True)
1173
+ if key is None:
1174
+ break
1175
+ yield key, batch
1176
+
1177
+ executor.shutdown()
1178
+
1179
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1180
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1181
+
1182
+ def prepare_for_training(self):
1183
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
1184
+
1185
+ # glob cache files
1186
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
1187
+
1188
+ # assign cache files to item info
1189
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
1190
+ for cache_file in latent_cache_files:
1191
+ tokens = os.path.basename(cache_file).split("_")
1192
+
1193
+ image_size = tokens[-2] # 0000x0000
1194
+ image_width, image_height = map(int, image_size.split("x"))
1195
+ image_size = (image_width, image_height)
1196
+
1197
+ frame_pos, frame_count = tokens[-3].split("-")
1198
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
1199
+
1200
+ item_key = "_".join(tokens[:-3])
1201
+ text_encoder_output_cache_file = os.path.join(
1202
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
1203
+ )
1204
+ if not os.path.exists(text_encoder_output_cache_file):
1205
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1206
+ continue
1207
+
1208
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1209
+ bucket_reso = (*bucket_reso, frame_count)
1210
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
1211
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1212
+
1213
+ bucket = bucketed_item_info.get(bucket_reso, [])
1214
+ bucket.append(item_info)
1215
+ bucketed_item_info[bucket_reso] = bucket
1216
+
1217
+ # prepare batch manager
1218
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1219
+ self.batch_manager.show_bucket_info()
1220
+
1221
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1222
+
1223
+ def shuffle_buckets(self):
1224
+ # set random seed for this epoch
1225
+ random.seed(self.seed + self.current_epoch)
1226
+ self.batch_manager.shuffle()
1227
+
1228
+ def __len__(self):
1229
+ if self.batch_manager is None:
1230
+ return 100 # dummy value
1231
+ return len(self.batch_manager)
1232
+
1233
+ def __getitem__(self, idx):
1234
+ return self.batch_manager[idx]
1235
+
1236
+
1237
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1238
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
1239
+ super().__init__(datasets)
1240
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
1241
+ self.num_train_items = 0
1242
+ for dataset in self.datasets:
1243
+ self.num_train_items += dataset.num_train_items
1244
+
1245
+ def set_current_epoch(self, epoch):
1246
+ for dataset in self.datasets:
1247
+ dataset.set_current_epoch(epoch)
1248
+
1249
+ def set_current_step(self, step):
1250
+ for dataset in self.datasets:
1251
+ dataset.set_current_step(step)
1252
+
1253
+ def set_max_train_steps(self, max_train_steps):
1254
+ for dataset in self.datasets:
1255
+ dataset.set_max_train_steps(max_train_steps)
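
The four `frame_extraction` modes handled inside `aggregate_future` above are easy to lose track of in the middle of the future-draining loop. Below is a standalone sketch of just that crop-selection step; the helper name `select_crops` and its defaults are illustrative, not part of the dataset code.

```python
# Illustrative re-statement of the crop-window selection in aggregate_future().
# frame_extraction / target_frames / frame_stride / frame_sample follow the
# same semantics as the dataset options; select_crops itself is hypothetical.
import numpy as np


def select_crops(frame_count, frame_extraction, target_frames, frame_stride=1, frame_sample=1):
    """Return (start_frame, length) pairs for one decoded video."""
    crops = []
    for target in target_frames:
        if frame_extraction == "head":
            if frame_count >= target:
                crops.append((0, target))
        elif frame_extraction == "chunk":
            for i in range(0, frame_count, target):
                if i + target <= frame_count:
                    crops.append((i, target))
        elif frame_extraction == "slide":
            if frame_count >= target:
                for i in range(0, frame_count - target + 1, frame_stride):
                    crops.append((i, target))
        elif frame_extraction == "uniform":
            if frame_count >= target:
                for i in np.linspace(0, frame_count - target, frame_sample, dtype=int):
                    crops.append((int(i), target))
        else:
            raise ValueError(f"frame_extraction {frame_extraction} is not supported")
    return crops


# A 100-frame clip with sliding 45-frame windows every 10 frames:
print(select_crops(100, "slide", [45], frame_stride=10))
# [(0, 45), (10, 45), (20, 45), (30, 45), (40, 45), (50, 45)]
```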
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ Callable[[], nn.Module]: a constructor for the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
hunyuan_model/attention.py ADDED
@@ -0,0 +1,230 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except ImportError:
13
+ flash_attn = None
14
+ flash_attn_varlen_func = None
15
+ _flash_attn_forward = None
16
+
17
+ try:
18
+ print(f"Trying to import sageattention")
19
+ from sageattention import sageattn_varlen
20
+
21
+ print("Successfully imported sageattention")
22
+ except ImportError:
23
+ print(f"Failed to import flash_attn and sageattention")
24
+ sageattn_varlen = None
25
+
26
+ MEMORY_LAYOUT = {
27
+ "flash": (
28
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
29
+ lambda x: x,
30
+ ),
31
+ "sageattn": (
32
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
33
+ lambda x: x,
34
+ ),
35
+ "torch": (
36
+ lambda x: x.transpose(1, 2),
37
+ lambda x: x.transpose(1, 2),
38
+ ),
39
+ "vanilla": (
40
+ lambda x: x.transpose(1, 2),
41
+ lambda x: x.transpose(1, 2),
42
+ ),
43
+ }
44
+
45
+
46
+ def get_cu_seqlens(text_mask, img_len):
47
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
48
+
49
+ Args:
50
+ text_mask (torch.Tensor): the mask of text
51
+ img_len (int): the length of image
52
+
53
+ Returns:
54
+ torch.Tensor: the calculated cu_seqlens for flash attention
55
+ """
56
+ batch_size = text_mask.shape[0]
57
+ text_len = text_mask.sum(dim=1)
58
+ max_len = text_mask.shape[1] + img_len
59
+
60
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
61
+
62
+ for i in range(batch_size):
63
+ s = text_len[i] + img_len
64
+ s1 = i * max_len + s
65
+ s2 = (i + 1) * max_len
66
+ cu_seqlens[2 * i + 1] = s1
67
+ cu_seqlens[2 * i + 2] = s2
68
+
69
+ return cu_seqlens
70
+
71
+
72
+ def attention(
73
+ q_or_qkv_list,
74
+ k=None,
75
+ v=None,
76
+ mode="flash",
77
+ drop_rate=0,
78
+ attn_mask=None,
79
+ causal=False,
80
+ cu_seqlens_q=None,
81
+ cu_seqlens_kv=None,
82
+ max_seqlen_q=None,
83
+ max_seqlen_kv=None,
84
+ batch_size=1,
85
+ ):
86
+ """
87
+ Perform QKV self attention.
88
+
89
+ Args:
90
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
91
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
92
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
93
+ mode (str): Attention mode. Choose from 'flash', 'sageattn', 'torch', and 'vanilla'.
94
+ drop_rate (float): Dropout rate in attention map. (default: 0)
95
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
96
+ (default: None)
97
+ causal (bool): Whether to use causal attention. (default: False)
98
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
99
+ used to index into q.
100
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
101
+ used to index into kv.
102
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
103
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
104
+
105
+ Returns:
106
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
107
+ """
108
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
109
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
110
+ q = pre_attn_layout(q)
111
+ k = pre_attn_layout(k)
112
+ v = pre_attn_layout(v)
113
+
114
+ if mode == "torch":
115
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
116
+ attn_mask = attn_mask.to(q.dtype)
117
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
118
+ if type(q_or_qkv_list) == list:
119
+ q_or_qkv_list.clear()
120
+ del q, k, v
121
+ del attn_mask
122
+ elif mode == "flash":
123
+ x = flash_attn_varlen_func(
124
+ q,
125
+ k,
126
+ v,
127
+ cu_seqlens_q,
128
+ cu_seqlens_kv,
129
+ max_seqlen_q,
130
+ max_seqlen_kv,
131
+ )
132
+ if type(q_or_qkv_list) == list:
133
+ q_or_qkv_list.clear()
134
+ del q, k, v
135
+ # x with shape [(bxs), a, d]
136
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
137
+ elif mode == "sageattn":
138
+ x = sageattn_varlen(
139
+ q,
140
+ k,
141
+ v,
142
+ cu_seqlens_q,
143
+ cu_seqlens_kv,
144
+ max_seqlen_q,
145
+ max_seqlen_kv,
146
+ )
147
+ if type(q_or_qkv_list) == list:
148
+ q_or_qkv_list.clear()
149
+ del q, k, v
150
+ # x with shape [(bxs), a, d]
151
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
152
+ elif mode == "vanilla":
153
+ scale_factor = 1 / math.sqrt(q.size(-1))
154
+
155
+ b, a, s, _ = q.shape
156
+ s1 = k.size(2)
157
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
158
+ if causal:
159
+ # Only applied to self attention
160
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
161
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
162
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
163
+ attn_bias.to(q.dtype)
164
+
165
+ if attn_mask is not None:
166
+ if attn_mask.dtype == torch.bool:
167
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
168
+ else:
169
+ attn_bias += attn_mask
170
+
171
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
172
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
173
+ attn += attn_bias
174
+ attn = attn.softmax(dim=-1)
175
+ attn = torch.dropout(attn, p=drop_rate, train=True)
176
+ x = attn @ v
177
+ else:
178
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
179
+
180
+ x = post_attn_layout(x)
181
+ b, s, a, d = x.shape
182
+ out = x.reshape(b, s, -1)
183
+ return out
184
+
185
+
186
+ def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
187
+ attn1 = hybrid_seq_parallel_attn(
188
+ None,
189
+ q[:, :img_q_len, :, :],
190
+ k[:, :img_kv_len, :, :],
191
+ v[:, :img_kv_len, :, :],
192
+ dropout_p=0.0,
193
+ causal=False,
194
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
195
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
196
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
197
+ joint_strategy="rear",
198
+ )
199
+ if flash_attn.__version__ >= "2.7.0":
200
+ attn2, *_ = _flash_attn_forward(
201
+ q[:, cu_seqlens_q[1] :],
202
+ k[:, cu_seqlens_kv[1] :],
203
+ v[:, cu_seqlens_kv[1] :],
204
+ dropout_p=0.0,
205
+ softmax_scale=q.shape[-1] ** (-0.5),
206
+ causal=False,
207
+ window_size_left=-1,
208
+ window_size_right=-1,
209
+ softcap=0.0,
210
+ alibi_slopes=None,
211
+ return_softmax=False,
212
+ )
213
+ else:
214
+ attn2, *_ = _flash_attn_forward(
215
+ q[:, cu_seqlens_q[1] :],
216
+ k[:, cu_seqlens_kv[1] :],
217
+ v[:, cu_seqlens_kv[1] :],
218
+ dropout_p=0.0,
219
+ softmax_scale=q.shape[-1] ** (-0.5),
220
+ causal=False,
221
+ window_size=(-1, -1),
222
+ softcap=0.0,
223
+ alibi_slopes=None,
224
+ return_softmax=False,
225
+ )
226
+ attn = torch.cat([attn1, attn2], dim=1)
227
+ b, s, a, d = attn.shape
228
+ attn = attn.reshape(b, s, -1)
229
+
230
+ return attn
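
The `attention` wrapper above can be exercised without flash_attn or sageattention by selecting the `torch` (scaled-dot-product) path. A minimal smoke test, with shapes following the `[batch, seq, heads, dim]` layout described in the docstring (values are arbitrary):

```python
import torch
from hunyuan_model.attention import attention

b, s, heads, dim = 1, 16, 4, 32
q = torch.randn(b, s, heads, dim)
k = torch.randn(b, s, heads, dim)
v = torch.randn(b, s, heads, dim)

out = attention(q, k, v, mode="torch")  # falls back to F.scaled_dot_product_attention
print(out.shape)                        # torch.Size([1, 16, 128]) == [b, s, heads * dim]
```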
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ try:
28
+ # This diffusers is modified and packed in the mirror.
29
+ from diffusers.loaders import FromOriginalVAEMixin
30
+ except ImportError:
31
+ # Use this to be compatible with the original diffusers.
32
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
127
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
128
+ self.tile_overlap_factor = 0.25
129
+
130
+ def _set_gradient_checkpointing(self, module, value=False):
131
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
132
+ module.gradient_checkpointing = value
133
+
134
+ def enable_temporal_tiling(self, use_tiling: bool = True):
135
+ self.use_temporal_tiling = use_tiling
136
+
137
+ def disable_temporal_tiling(self):
138
+ self.enable_temporal_tiling(False)
139
+
140
+ def enable_spatial_tiling(self, use_tiling: bool = True):
141
+ self.use_spatial_tiling = use_tiling
142
+
143
+ def disable_spatial_tiling(self):
144
+ self.enable_spatial_tiling(False)
145
+
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
150
+ processing larger videos.
151
+ """
152
+ self.enable_spatial_tiling(use_tiling)
153
+ self.enable_temporal_tiling(use_tiling)
154
+
155
+ def disable_tiling(self):
156
+ r"""
157
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
158
+ decoding in one step.
159
+ """
160
+ self.disable_spatial_tiling()
161
+ self.disable_temporal_tiling()
162
+
163
+ def enable_slicing(self):
164
+ r"""
165
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
166
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
167
+ """
168
+ self.use_slicing = True
169
+
170
+ def disable_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
173
+ decoding in one step.
174
+ """
175
+ self.use_slicing = False
176
+
177
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
178
+ # set chunk_size to CausalConv3d recursively
179
+ def set_chunk_size(module):
180
+ if hasattr(module, "chunk_size"):
181
+ module.chunk_size = chunk_size
182
+
183
+ self.apply(set_chunk_size)
184
+
185
+ @property
186
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
187
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
188
+ r"""
189
+ Returns:
190
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
191
+ indexed by its weight name.
192
+ """
193
+ # set recursively
194
+ processors = {}
195
+
196
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
197
+ if hasattr(module, "get_processor"):
198
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
199
+
200
+ for sub_name, child in module.named_children():
201
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
202
+
203
+ return processors
204
+
205
+ for name, module in self.named_children():
206
+ fn_recursive_add_processors(name, module, processors)
207
+
208
+ return processors
209
+
210
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
211
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
212
+ r"""
213
+ Sets the attention processor to use to compute attention.
214
+
215
+ Parameters:
216
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
217
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
218
+ for **all** `Attention` layers.
219
+
220
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
221
+ processor. This is strongly recommended when setting trainable attention processors.
222
+
223
+ """
224
+ count = len(self.attn_processors.keys())
225
+
226
+ if isinstance(processor, dict) and len(processor) != count:
227
+ raise ValueError(
228
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
229
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
230
+ )
231
+
232
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
233
+ if hasattr(module, "set_processor"):
234
+ if not isinstance(processor, dict):
235
+ module.set_processor(processor, _remove_lora=_remove_lora)
236
+ else:
237
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
238
+
239
+ for sub_name, child in module.named_children():
240
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
241
+
242
+ for name, module in self.named_children():
243
+ fn_recursive_attn_processor(name, module, processor)
244
+
245
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
246
+ def set_default_attn_processor(self):
247
+ """
248
+ Disables custom attention processors and sets the default attention implementation.
249
+ """
250
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnAddedKVProcessor()
252
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
253
+ processor = AttnProcessor()
254
+ else:
255
+ raise ValueError(
256
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
257
+ )
258
+
259
+ self.set_attn_processor(processor, _remove_lora=True)
260
+
261
+ @apply_forward_hook
262
+ def encode(
263
+ self, x: torch.FloatTensor, return_dict: bool = True
264
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
+ """
266
+ Encode a batch of images/videos into latents.
267
+
268
+ Args:
269
+ x (`torch.FloatTensor`): Input batch of images/videos.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
+
273
+ Returns:
274
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
275
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
+ """
277
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
278
+
279
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
280
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
281
+
282
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
283
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
284
+
285
+ if self.use_slicing and x.shape[0] > 1:
286
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
287
+ h = torch.cat(encoded_slices)
288
+ else:
289
+ h = self.encoder(x)
290
+
291
+ moments = self.quant_conv(h)
292
+ posterior = DiagonalGaussianDistribution(moments)
293
+
294
+ if not return_dict:
295
+ return (posterior,)
296
+
297
+ return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
301
+
302
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
303
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
304
+
305
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
306
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
307
+
308
+ z = self.post_quant_conv(z)
309
+ dec = self.decoder(z)
310
+
311
+ if not return_dict:
312
+ return (dec,)
313
+
314
+ return DecoderOutput(sample=dec)
315
+
316
+ @apply_forward_hook
317
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(
362
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
363
+ ) -> AutoencoderKLOutput:
364
+ r"""Encode a batch of images/videos using a tiled encoder.
365
+
366
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
367
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
368
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
369
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
370
+ output, but they should be much less noticeable.
371
+
372
+ Args:
373
+ x (`torch.FloatTensor`): Input batch of images/videos.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
376
+
377
+ Returns:
378
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
379
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
380
+ `tuple` is returned.
381
+ """
382
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
383
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
384
+ row_limit = self.tile_latent_min_size - blend_extent
385
+
386
+ # Split video into tiles and encode them separately.
387
+ rows = []
388
+ for i in range(0, x.shape[-2], overlap_size):
389
+ row = []
390
+ for j in range(0, x.shape[-1], overlap_size):
391
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
392
+ tile = self.encoder(tile)
393
+ tile = self.quant_conv(tile)
394
+ row.append(tile)
395
+ rows.append(row)
396
+ result_rows = []
397
+ for i, row in enumerate(rows):
398
+ result_row = []
399
+ for j, tile in enumerate(row):
400
+ # blend the above tile and the left tile
401
+ # to the current tile and add the current tile to the result row
402
+ if i > 0:
403
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
404
+ if j > 0:
405
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
406
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
407
+ result_rows.append(torch.cat(result_row, dim=-1))
408
+
409
+ moments = torch.cat(result_rows, dim=-2)
410
+ if return_moments:
411
+ return moments
412
+
413
+ posterior = DiagonalGaussianDistribution(moments)
414
+ if not return_dict:
415
+ return (posterior,)
416
+
417
+ return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
420
+ r"""
421
+ Decode a batch of images/videos using a tiled decoder.
422
+
423
+ Args:
424
+ z (`torch.FloatTensor`): Input batch of latent vectors.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
427
+
428
+ Returns:
429
+ [`~models.vae.DecoderOutput`] or `tuple`:
430
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
431
+ returned.
432
+ """
433
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
434
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
435
+ row_limit = self.tile_sample_min_size - blend_extent
436
+
437
+ # Split z into overlapping tiles and decode them separately.
438
+ # The tiles have an overlap to avoid seams between tiles.
439
+ rows = []
440
+ for i in range(0, z.shape[-2], overlap_size):
441
+ row = []
442
+ for j in range(0, z.shape[-1], overlap_size):
443
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
444
+ tile = self.post_quant_conv(tile)
445
+ decoded = self.decoder(tile)
446
+ row.append(decoded)
447
+ rows.append(row)
448
+ result_rows = []
449
+ for i, row in enumerate(rows):
450
+ result_row = []
451
+ for j, tile in enumerate(row):
452
+ # blend the above tile and the left tile
453
+ # to the current tile and add the current tile to the result row
454
+ if i > 0:
455
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
456
+ if j > 0:
457
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
458
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
459
+ result_rows.append(torch.cat(result_row, dim=-1))
460
+
461
+ dec = torch.cat(result_rows, dim=-2)
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
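
For the VAE above, the tiling and slicing switches are the main levers for keeping decode memory bounded on long or high-resolution clips. A hedged sketch of how they are meant to be combined at inference time; `vae` is assumed to be an `AutoencoderKLCausal3D` loaded elsewhere, and the function name is illustrative:

```python
import torch


def decode_with_low_memory(vae, latents: torch.Tensor) -> torch.Tensor:
    """latents: [B, C, T, H, W] latent tensor; returns the decoded video tensor."""
    vae.enable_spatial_tiling(True)    # tile over H/W for large frames
    vae.enable_temporal_tiling(True)   # tile over T for long clips
    vae.enable_slicing()               # decode one batch element at a time
    with torch.no_grad():
        return vae.decode(latents, return_dict=True).sample
```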
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image/Video to Patch Embedding using Conv3d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
41
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
42
+ if bias:
43
+ nn.init.zeros_(self.proj.bias)
44
+
45
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
46
+
47
+ def forward(self, x):
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
51
+ x = self.norm(x)
52
+ return x
53
+
54
+
55
+ class TextProjection(nn.Module):
56
+ """
57
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {"dtype": dtype, "device": device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
92
+ args = t[:, None].float() * freqs[None]
93
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
94
+ if dim % 2:
95
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
96
+ return embedding
97
+
98
+
99
+ class TimestepEmbedder(nn.Module):
100
+ """
101
+ Embeds scalar timesteps into vector representations.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ act_layer,
108
+ frequency_embedding_size=256,
109
+ max_period=10000,
110
+ out_size=None,
111
+ dtype=None,
112
+ device=None,
113
+ ):
114
+ factory_kwargs = {"dtype": dtype, "device": device}
115
+ super().__init__()
116
+ self.frequency_embedding_size = frequency_embedding_size
117
+ self.max_period = max_period
118
+ if out_size is None:
119
+ out_size = hidden_size
120
+
121
+ self.mlp = nn.Sequential(
122
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
123
+ act_layer(),
124
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
125
+ )
126
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
127
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
128
+
129
+ def forward(self, t):
130
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
131
+ t_emb = self.mlp(t_freq)
132
+ return t_emb
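
A quick shape check for the embedding layers above (purely illustrative; assumes torch and einops from the repo requirements are installed):

```python
import torch
import torch.nn as nn
from hunyuan_model.embed_layers import timestep_embedding, TimestepEmbedder

t = torch.tensor([0.0, 250.0, 999.0])             # one timestep per sample
print(timestep_embedding(t, dim=256).shape)       # torch.Size([3, 256])

embedder = TimestepEmbedder(hidden_size=512, act_layer=nn.SiLU)
print(embedder(t).shape)                          # torch.Size([3, 512])
```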
hunyuan_model/helpers.py ADDED
@@ -0,0 +1,40 @@
1
+ import collections.abc
2
+
3
+ from itertools import repeat
4
+
5
+
6
+ def _ntuple(n):
7
+ def parse(x):
8
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9
+ x = tuple(x)
10
+ if len(x) == 1:
11
+ x = tuple(repeat(x[0], n))
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+
22
+
23
+ def as_tuple(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ if x is None or isinstance(x, (int, float, str)):
27
+ return (x,)
28
+ else:
29
+ raise ValueError(f"Unknown type {type(x)}")
30
+
31
+
32
+ def as_list_of_2tuple(x):
33
+ x = as_tuple(x)
34
+ if len(x) == 1:
35
+ x = (x[0], x[0])
36
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37
+ lst = []
38
+ for i in range(0, len(x), 2):
39
+ lst.append((x[i], x[i + 1]))
40
+ return lst
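
The tuple helpers above are used throughout the model code for patch and kernel sizes; their behavior at a glance (illustrative):

```python
from hunyuan_model.helpers import to_2tuple, as_list_of_2tuple

print(to_2tuple(16))                     # (16, 16)
print(to_2tuple((8, 4)))                 # (8, 4)
print(as_list_of_2tuple((1, 2, 3, 4)))   # [(1, 2), (3, 4)]
```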
hunyuan_model/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from .helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+
62
+ #
63
+ class MLPEmbedder(nn.Module):
64
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
+ factory_kwargs = {"device": device, "dtype": dtype}
67
+ super().__init__()
68
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
+ self.silu = nn.SiLU()
70
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.out_layer(self.silu(self.in_layer(x)))
74
+
75
+
76
+ class FinalLayer(nn.Module):
77
+ """The final layer of DiT."""
78
+
79
+ def __init__(
80
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
+ ):
82
+ factory_kwargs = {"device": device, "dtype": dtype}
83
+ super().__init__()
84
+
85
+ # Just use LayerNorm for the final layer
86
+ self.norm_final = nn.LayerNorm(
87
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
+ )
89
+ if isinstance(patch_size, int):
90
+ self.linear = nn.Linear(
91
+ hidden_size,
92
+ patch_size * patch_size * out_channels,
93
+ bias=True,
94
+ **factory_kwargs
95
+ )
96
+ else:
97
+ self.linear = nn.Linear(
98
+ hidden_size,
99
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
+ bias=True,
101
+ )
102
+ nn.init.zeros_(self.linear.weight)
103
+ nn.init.zeros_(self.linear.bias)
104
+
105
+ # Here we don't distinguish between the modulate types. Just use the simple one.
106
+ self.adaLN_modulation = nn.Sequential(
107
+ act_layer(),
108
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
+ )
110
+ # Zero-initialize the modulation
111
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
112
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
113
+
114
+ def forward(self, x, c):
115
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
+ x = self.linear(x)
118
+ return x
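
`MLPEmbedder` above is the Flux-style two-layer projection used for conditioning vectors; a shape-only sketch with illustrative dimensions:

```python
import torch
from hunyuan_model.mlp_layers import MLPEmbedder

embedder = MLPEmbedder(in_dim=768, hidden_dim=3072)
print(embedder(torch.randn(2, 768)).shape)  # torch.Size([2, 3072])
```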
hunyuan_model/models.py ADDED
@@ -0,0 +1,997 @@
1
+ import os
2
+ from typing import Any, List, Tuple, Optional, Union, Dict
3
+ import accelerate
4
+ from einops import rearrange
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .activation_layers import get_activation_layer
11
+ from .norm_layers import get_norm_layer
12
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
13
+ from .attention import attention, parallel_attention, get_cu_seqlens
14
+ from .posemb_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
19
+ from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
20
+
21
+ from utils.safetensors_utils import MemoryEfficientSafeOpen
22
+
23
+
24
+ class MMDoubleStreamBlock(nn.Module):
25
+ """
26
+ A multimodal DiT block with separate modulation for
27
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
28
+ (Flux.1): https://github.com/black-forest-labs/flux
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ hidden_size: int,
34
+ heads_num: int,
35
+ mlp_width_ratio: float,
36
+ mlp_act_type: str = "gelu_tanh",
37
+ qk_norm: bool = True,
38
+ qk_norm_type: str = "rms",
39
+ qkv_bias: bool = False,
40
+ dtype: Optional[torch.dtype] = None,
41
+ device: Optional[torch.device] = None,
42
+ attn_mode: str = "flash",
43
+ ):
44
+ factory_kwargs = {"device": device, "dtype": dtype}
45
+ super().__init__()
46
+ self.attn_mode = attn_mode
47
+
48
+ self.deterministic = False
49
+ self.heads_num = heads_num
50
+ head_dim = hidden_size // heads_num
51
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
52
+
53
+ self.img_mod = ModulateDiT(
54
+ hidden_size,
55
+ factor=6,
56
+ act_layer=get_activation_layer("silu"),
57
+ **factory_kwargs,
58
+ )
59
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
60
+
61
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
62
+ qk_norm_layer = get_norm_layer(qk_norm_type)
63
+ self.img_attn_q_norm = (
64
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
65
+ )
66
+ self.img_attn_k_norm = (
67
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
68
+ )
69
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
70
+
71
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
72
+ self.img_mlp = MLP(
73
+ hidden_size,
74
+ mlp_hidden_dim,
75
+ act_layer=get_activation_layer(mlp_act_type),
76
+ bias=True,
77
+ **factory_kwargs,
78
+ )
79
+
80
+ self.txt_mod = ModulateDiT(
81
+ hidden_size,
82
+ factor=6,
83
+ act_layer=get_activation_layer("silu"),
84
+ **factory_kwargs,
85
+ )
86
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
87
+
88
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
89
+ self.txt_attn_q_norm = (
90
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
91
+ )
92
+ self.txt_attn_k_norm = (
93
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
94
+ )
95
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
96
+
97
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
98
+ self.txt_mlp = MLP(
99
+ hidden_size,
100
+ mlp_hidden_dim,
101
+ act_layer=get_activation_layer(mlp_act_type),
102
+ bias=True,
103
+ **factory_kwargs,
104
+ )
105
+ self.hybrid_seq_parallel_attn = None
106
+
107
+ self.gradient_checkpointing = False
108
+
109
+ def enable_deterministic(self):
110
+ self.deterministic = True
111
+
112
+ def disable_deterministic(self):
113
+ self.deterministic = False
114
+
115
+ def enable_gradient_checkpointing(self):
116
+ self.gradient_checkpointing = True
117
+
118
+ def _forward(
119
+ self,
120
+ img: torch.Tensor,
121
+ txt: torch.Tensor,
122
+ vec: torch.Tensor,
123
+ attn_mask: Optional[torch.Tensor] = None,
124
+ cu_seqlens_q: Optional[torch.Tensor] = None,
125
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
126
+ max_seqlen_q: Optional[int] = None,
127
+ max_seqlen_kv: Optional[int] = None,
128
+ freqs_cis: tuple = None,
129
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
130
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
131
+ 6, dim=-1
132
+ )
133
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
134
+ 6, dim=-1
135
+ )
136
+
137
+ # Prepare image for attention.
138
+ img_modulated = self.img_norm1(img)
139
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
140
+ img_qkv = self.img_attn_qkv(img_modulated)
141
+ img_modulated = None
142
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
143
+ img_qkv = None
144
+ # Apply QK-Norm if needed
145
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
146
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
147
+
148
+ # Apply RoPE if needed.
149
+ if freqs_cis is not None:
150
+ img_q_shape = img_q.shape
151
+ img_k_shape = img_k.shape
152
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
153
+ assert (
154
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
155
+ ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
156
+ # img_q, img_k = img_qq, img_kk
157
+
158
+ # Prepare txt for attention.
159
+ txt_modulated = self.txt_norm1(txt)
160
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
161
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
162
+ txt_modulated = None
163
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
164
+ txt_qkv = None
165
+ # Apply QK-Norm if needed.
166
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
167
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
168
+
169
+ # Run actual attention.
170
+ img_q_len = img_q.shape[1]
171
+ img_kv_len = img_k.shape[1]
172
+ batch_size = img_k.shape[0]
173
+ q = torch.cat((img_q, txt_q), dim=1)
174
+ img_q = txt_q = None
175
+ k = torch.cat((img_k, txt_k), dim=1)
176
+ img_k = txt_k = None
177
+ v = torch.cat((img_v, txt_v), dim=1)
178
+ img_v = txt_v = None
179
+
180
+ assert (
181
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
182
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
183
+
184
+ # attention computation start
185
+ if not self.hybrid_seq_parallel_attn:
186
+ l = [q, k, v]
187
+ q = k = v = None
188
+ attn = attention(
189
+ l,
190
+ mode=self.attn_mode,
191
+ attn_mask=attn_mask,
192
+ cu_seqlens_q=cu_seqlens_q,
193
+ cu_seqlens_kv=cu_seqlens_kv,
194
+ max_seqlen_q=max_seqlen_q,
195
+ max_seqlen_kv=max_seqlen_kv,
196
+ batch_size=batch_size,
197
+ )
198
+ else:
199
+ attn = parallel_attention(
200
+ self.hybrid_seq_parallel_attn,
201
+ q,
202
+ k,
203
+ v,
204
+ img_q_len=img_q_len,
205
+ img_kv_len=img_kv_len,
206
+ cu_seqlens_q=cu_seqlens_q,
207
+ cu_seqlens_kv=cu_seqlens_kv,
208
+ )
209
+
210
+ # attention computation end
211
+
212
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
213
+ attn = None
214
+
215
+ # Calculate the img blocks.
216
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
217
+ img_attn = None
218
+ img = img + apply_gate(
219
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
220
+ gate=img_mod2_gate,
221
+ )
222
+
223
+ # Calculate the txt blocks.
224
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
225
+ txt_attn = None
226
+ txt = txt + apply_gate(
227
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
228
+ gate=txt_mod2_gate,
229
+ )
230
+
231
+ return img, txt
232
+
233
+ # def forward(
234
+ # self,
235
+ # img: torch.Tensor,
236
+ # txt: torch.Tensor,
237
+ # vec: torch.Tensor,
238
+ # attn_mask: Optional[torch.Tensor] = None,
239
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
240
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
241
+ # max_seqlen_q: Optional[int] = None,
242
+ # max_seqlen_kv: Optional[int] = None,
243
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
244
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
245
+ def forward(self, *args, **kwargs):
246
+ if self.training and self.gradient_checkpointing:
247
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
248
+ else:
249
+ return self._forward(*args, **kwargs)
250
+
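+ # Note on the forward wrapper above (illustrative, assuming `checkpoint` is
+ # torch.utils.checkpoint.checkpoint): during training with checkpointing enabled the
+ # block re-runs _forward in the backward pass instead of storing activations, trading
+ # extra compute for memory; use_reentrant=False selects the non-reentrant variant.
+ # e.g. block.enable_gradient_checkpointing(); img, txt = block(img, txt, vec, ...)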
251
+
252
+ class MMSingleStreamBlock(nn.Module):
253
+ """
254
+ A DiT block with parallel linear layers as described in
255
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
256
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
257
+ (Flux.1): https://github.com/black-forest-labs/flux
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ hidden_size: int,
263
+ heads_num: int,
264
+ mlp_width_ratio: float = 4.0,
265
+ mlp_act_type: str = "gelu_tanh",
266
+ qk_norm: bool = True,
267
+ qk_norm_type: str = "rms",
268
+ qk_scale: float = None,
269
+ dtype: Optional[torch.dtype] = None,
270
+ device: Optional[torch.device] = None,
271
+ attn_mode: str = "flash",
272
+ ):
273
+ factory_kwargs = {"device": device, "dtype": dtype}
274
+ super().__init__()
275
+ self.attn_mode = attn_mode
276
+
277
+ self.deterministic = False
278
+ self.hidden_size = hidden_size
279
+ self.heads_num = heads_num
280
+ head_dim = hidden_size // heads_num
281
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
282
+ self.mlp_hidden_dim = mlp_hidden_dim
283
+ self.scale = qk_scale or head_dim**-0.5
284
+
285
+ # qkv and mlp_in
286
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
287
+ # proj and mlp_out
288
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
289
+
290
+ qk_norm_layer = get_norm_layer(qk_norm_type)
291
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
292
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
293
+
294
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
295
+
296
+ self.mlp_act = get_activation_layer(mlp_act_type)()
297
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
298
+ self.hybrid_seq_parallel_attn = None
299
+
300
+ self.gradient_checkpointing = False
301
+
302
+ def enable_deterministic(self):
303
+ self.deterministic = True
304
+
305
+ def disable_deterministic(self):
306
+ self.deterministic = False
307
+
308
+ def enable_gradient_checkpointing(self):
309
+ self.gradient_checkpointing = True
310
+
311
+ def _forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ vec: torch.Tensor,
315
+ txt_len: int,
316
+ attn_mask: Optional[torch.Tensor] = None,
317
+ cu_seqlens_q: Optional[torch.Tensor] = None,
318
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
319
+ max_seqlen_q: Optional[int] = None,
320
+ max_seqlen_kv: Optional[int] = None,
321
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
322
+ ) -> torch.Tensor:
323
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
324
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
325
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
326
+ x_mod = None
327
+ # mlp = mlp.to("cpu", non_blocking=True)
328
+ # clean_memory_on_device(x.device)
329
+
330
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
331
+ qkv = None
332
+
333
+ # Apply QK-Norm if needed.
334
+ q = self.q_norm(q).to(v)
335
+ k = self.k_norm(k).to(v)
336
+
337
+ # Apply RoPE if needed.
338
+ if freqs_cis is not None:
339
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
340
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
341
+ q = k = None
342
+ img_q_shape = img_q.shape
343
+ img_k_shape = img_k.shape
344
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
345
+ assert (
346
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
347
+ ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
348
+ # img_q, img_k = img_qq, img_kk
349
+ # del img_qq, img_kk
350
+ q = torch.cat((img_q, txt_q), dim=1)
351
+ k = torch.cat((img_k, txt_k), dim=1)
352
+ del img_q, txt_q, img_k, txt_k
353
+
354
+ # Compute attention.
355
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
356
+
357
+ # attention computation start
358
+ if not self.hybrid_seq_parallel_attn:
359
+ l = [q, k, v]
360
+ q = k = v = None
361
+ attn = attention(
362
+ l,
363
+ mode=self.attn_mode,
364
+ attn_mask=attn_mask,
365
+ cu_seqlens_q=cu_seqlens_q,
366
+ cu_seqlens_kv=cu_seqlens_kv,
367
+ max_seqlen_q=max_seqlen_q,
368
+ max_seqlen_kv=max_seqlen_kv,
369
+ batch_size=x.shape[0],
370
+ )
371
+ else:
372
+ attn = parallel_attention(
373
+ self.hybrid_seq_parallel_attn,
374
+ q,
375
+ k,
376
+ v,
377
+ img_q_len=img_q.shape[1],
378
+ img_kv_len=img_k.shape[1],
379
+ cu_seqlens_q=cu_seqlens_q,
380
+ cu_seqlens_kv=cu_seqlens_kv,
381
+ )
382
+ # attention computation end
383
+
384
+ # Compute activation in mlp stream, cat again and run second linear layer.
385
+ # mlp = mlp.to(x.device)
386
+ mlp = self.mlp_act(mlp)
387
+ attn_mlp = torch.cat((attn, mlp), 2)
388
+ attn = None
389
+ mlp = None
390
+ output = self.linear2(attn_mlp)
391
+ attn_mlp = None
392
+ return x + apply_gate(output, gate=mod_gate)
393
+
394
+ # def forward(
395
+ # self,
396
+ # x: torch.Tensor,
397
+ # vec: torch.Tensor,
398
+ # txt_len: int,
399
+ # attn_mask: Optional[torch.Tensor] = None,
400
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
401
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
402
+ # max_seqlen_q: Optional[int] = None,
403
+ # max_seqlen_kv: Optional[int] = None,
404
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
405
+ # ) -> torch.Tensor:
406
+ def forward(self, *args, **kwargs):
407
+ if self.training and self.gradient_checkpointing:
408
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
409
+ else:
410
+ return self._forward(*args, **kwargs)
411
+
412
+
413
+ class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
414
+ """
415
+ HunyuanVideo Transformer backbone
416
+
417
+ Originally inherited from ModelMixin and ConfigMixin for compatibility with diffusers' samplers; here it is a plain nn.Module.
418
+
419
+ Reference:
420
+ [1] Flux.1: https://github.com/black-forest-labs/flux
421
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
422
+
423
+ Parameters
424
+ ----------
425
+ args: argparse.Namespace
426
+ The arguments parsed by argparse.
427
+ patch_size: list
428
+ The size of the patch.
429
+ in_channels: int
430
+ The number of input channels.
431
+ out_channels: int
432
+ The number of output channels.
433
+ hidden_size: int
434
+ The hidden size of the transformer backbone.
435
+ heads_num: int
436
+ The number of attention heads.
437
+ mlp_width_ratio: float
438
+ The ratio of the hidden size of the MLP in the transformer block.
439
+ mlp_act_type: str
440
+ The activation function of the MLP in the transformer block.
441
+ depth_double_blocks: int
442
+ The number of transformer blocks in the double blocks.
443
+ depth_single_blocks: int
444
+ The number of transformer blocks in the single blocks.
445
+ rope_dim_list: list
446
+ The dimension of the rotary embedding for t, h, w.
447
+ qkv_bias: bool
448
+ Whether to use bias in the qkv linear layer.
449
+ qk_norm: bool
450
+ Whether to use qk norm.
451
+ qk_norm_type: str
452
+ The type of qk norm.
453
+ guidance_embed: bool
454
+ Whether to use guidance embedding for distillation.
455
+ text_projection: str
456
+ The type of the text projection, default is single_refiner.
457
+ use_attention_mask: bool
458
+ Whether to use attention mask for text encoder.
459
+ dtype: torch.dtype
460
+ The dtype of the model.
461
+ device: torch.device
462
+ The device of the model.
463
+ attn_mode: str
464
+ The mode of the attention, default is flash.
465
+ """
466
+
467
+ # @register_to_config
468
+ def __init__(
469
+ self,
470
+ text_states_dim: int,
471
+ text_states_dim_2: int,
472
+ patch_size: list = [1, 2, 2],
473
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
474
+ out_channels: int = None,
475
+ hidden_size: int = 3072,
476
+ heads_num: int = 24,
477
+ mlp_width_ratio: float = 4.0,
478
+ mlp_act_type: str = "gelu_tanh",
479
+ mm_double_blocks_depth: int = 20,
480
+ mm_single_blocks_depth: int = 40,
481
+ rope_dim_list: List[int] = [16, 56, 56],
482
+ qkv_bias: bool = True,
483
+ qk_norm: bool = True,
484
+ qk_norm_type: str = "rms",
485
+ guidance_embed: bool = False, # For modulation.
486
+ text_projection: str = "single_refiner",
487
+ use_attention_mask: bool = True,
488
+ dtype: Optional[torch.dtype] = None,
489
+ device: Optional[torch.device] = None,
490
+ attn_mode: str = "flash",
491
+ ):
492
+ factory_kwargs = {"device": device, "dtype": dtype}
493
+ super().__init__()
494
+
495
+ self.patch_size = patch_size
496
+ self.in_channels = in_channels
497
+ self.out_channels = in_channels if out_channels is None else out_channels
498
+ self.unpatchify_channels = self.out_channels
499
+ self.guidance_embed = guidance_embed
500
+ self.rope_dim_list = rope_dim_list
501
+
502
+ # Text projection. Default to linear projection.
503
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
504
+ self.use_attention_mask = use_attention_mask
505
+ self.text_projection = text_projection
506
+
507
+ self.text_states_dim = text_states_dim
508
+ self.text_states_dim_2 = text_states_dim_2
509
+
510
+ if hidden_size % heads_num != 0:
511
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
512
+ pe_dim = hidden_size // heads_num
513
+ if sum(rope_dim_list) != pe_dim:
514
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
515
+ self.hidden_size = hidden_size
516
+ self.heads_num = heads_num
517
+
518
+ self.attn_mode = attn_mode
519
+
520
+ # image projection
521
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
522
+
523
+ # text projection
524
+ if self.text_projection == "linear":
525
+ self.txt_in = TextProjection(
526
+ self.text_states_dim,
527
+ self.hidden_size,
528
+ get_activation_layer("silu"),
529
+ **factory_kwargs,
530
+ )
531
+ elif self.text_projection == "single_refiner":
532
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
533
+ else:
534
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
535
+
536
+ # time modulation
537
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
538
+
539
+ # text modulation
540
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
541
+
542
+ # guidance modulation
543
+ self.guidance_in = (
544
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
545
+ )
546
+
547
+ # double blocks
548
+ self.double_blocks = nn.ModuleList(
549
+ [
550
+ MMDoubleStreamBlock(
551
+ self.hidden_size,
552
+ self.heads_num,
553
+ mlp_width_ratio=mlp_width_ratio,
554
+ mlp_act_type=mlp_act_type,
555
+ qk_norm=qk_norm,
556
+ qk_norm_type=qk_norm_type,
557
+ qkv_bias=qkv_bias,
558
+ attn_mode=attn_mode,
559
+ **factory_kwargs,
560
+ )
561
+ for _ in range(mm_double_blocks_depth)
562
+ ]
563
+ )
564
+
565
+ # single blocks
566
+ self.single_blocks = nn.ModuleList(
567
+ [
568
+ MMSingleStreamBlock(
569
+ self.hidden_size,
570
+ self.heads_num,
571
+ mlp_width_ratio=mlp_width_ratio,
572
+ mlp_act_type=mlp_act_type,
573
+ qk_norm=qk_norm,
574
+ qk_norm_type=qk_norm_type,
575
+ attn_mode=attn_mode,
576
+ **factory_kwargs,
577
+ )
578
+ for _ in range(mm_single_blocks_depth)
579
+ ]
580
+ )
581
+
582
+ self.final_layer = FinalLayer(
583
+ self.hidden_size,
584
+ self.patch_size,
585
+ self.out_channels,
586
+ get_activation_layer("silu"),
587
+ **factory_kwargs,
588
+ )
589
+
590
+ self.gradient_checkpointing = False
591
+ self.blocks_to_swap = None
592
+ self.offloader_double = None
593
+ self.offloader_single = None
594
+ self._enable_img_in_txt_in_offloading = False
595
+
596
+ @property
597
+ def device(self):
598
+ return next(self.parameters()).device
599
+
600
+ @property
601
+ def dtype(self):
602
+ return next(self.parameters()).dtype
603
+
604
+ def enable_gradient_checkpointing(self):
605
+ self.gradient_checkpointing = True
606
+
607
+ self.txt_in.enable_gradient_checkpointing()
608
+
609
+ for block in self.double_blocks + self.single_blocks:
610
+ block.enable_gradient_checkpointing()
611
+
612
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
613
+
614
+ def enable_img_in_txt_in_offloading(self):
615
+ self._enable_img_in_txt_in_offloading = True
616
+
617
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
618
+ self.blocks_to_swap = num_blocks
619
+ self.num_double_blocks = len(self.double_blocks)
620
+ self.num_single_blocks = len(self.single_blocks)
621
+ double_blocks_to_swap = num_blocks // 2
622
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
623
+
624
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
625
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
626
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
627
+ )
628
+
629
+ self.offloader_double = ModelOffloader(
630
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
631
+ )
632
+ self.offloader_single = ModelOffloader(
633
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
634
+ )
635
+ print(
636
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
637
+ )
638
+
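+ # Worked example of the swap split above (a sketch, not from the original code):
+ # with num_blocks=20, double_blocks_to_swap = 20 // 2 = 10 and
+ # single_blocks_to_swap = (20 - 10) * 2 + 1 = 21; the assert caps these at
+ # num_double_blocks - 1 and num_single_blocks - 1 (19 and 39 for the default config).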
639
+ def move_to_device_except_swap_blocks(self, device: torch.device):
640
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
641
+ if self.blocks_to_swap:
642
+ save_double_blocks = self.double_blocks
643
+ save_single_blocks = self.single_blocks
644
+ self.double_blocks = None
645
+ self.single_blocks = None
646
+
647
+ self.to(device)
648
+
649
+ if self.blocks_to_swap:
650
+ self.double_blocks = save_double_blocks
651
+ self.single_blocks = save_single_blocks
652
+
653
+ def prepare_block_swap_before_forward(self):
654
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
655
+ return
656
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
657
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
658
+
659
+ def enable_deterministic(self):
660
+ for block in self.double_blocks:
661
+ block.enable_deterministic()
662
+ for block in self.single_blocks:
663
+ block.enable_deterministic()
664
+
665
+ def disable_deterministic(self):
666
+ for block in self.double_blocks:
667
+ block.disable_deterministic()
668
+ for block in self.single_blocks:
669
+ block.disable_deterministic()
670
+
671
+ def forward(
672
+ self,
673
+ x: torch.Tensor,
674
+ t: torch.Tensor, # Should be in range(0, 1000).
675
+ text_states: torch.Tensor = None,
676
+ text_mask: torch.Tensor = None, # Now we don't use it.
677
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
678
+ freqs_cos: Optional[torch.Tensor] = None,
679
+ freqs_sin: Optional[torch.Tensor] = None,
680
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
681
+ return_dict: bool = True,
682
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
683
+ out = {}
684
+ img = x
685
+ txt = text_states
686
+ _, _, ot, oh, ow = x.shape
687
+ tt, th, tw = (
688
+ ot // self.patch_size[0],
689
+ oh // self.patch_size[1],
690
+ ow // self.patch_size[2],
691
+ )
692
+
693
+ # Prepare modulation vectors.
694
+ vec = self.time_in(t)
695
+
696
+ # text modulation
697
+ vec = vec + self.vector_in(text_states_2)
698
+
699
+ # guidance modulation
700
+ if self.guidance_embed:
701
+ if guidance is None:
702
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
703
+
704
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
705
+ vec = vec + self.guidance_in(guidance)
706
+
707
+ # Embed image and text.
708
+ if self._enable_img_in_txt_in_offloading:
709
+ self.img_in.to(x.device, non_blocking=True)
710
+ self.txt_in.to(x.device, non_blocking=True)
711
+ synchronize_device(x.device)
712
+
713
+ img = self.img_in(img)
714
+ if self.text_projection == "linear":
715
+ txt = self.txt_in(txt)
716
+ elif self.text_projection == "single_refiner":
717
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
718
+ else:
719
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
720
+
721
+ if self._enable_img_in_txt_in_offloading:
722
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
723
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
724
+ synchronize_device(x.device)
725
+ clean_memory_on_device(x.device)
726
+
727
+ txt_seq_len = txt.shape[1]
728
+ img_seq_len = img.shape[1]
729
+
730
+ # Compute cu_seqlens and max_seqlen for flash attention
731
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
732
+ cu_seqlens_kv = cu_seqlens_q
733
+ max_seqlen_q = img_seq_len + txt_seq_len
734
+ max_seqlen_kv = max_seqlen_q
735
+
736
+ attn_mask = None
737
+ if self.attn_mode == "torch":
738
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
739
+ bs = img.shape[0]
740
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
741
+
742
+ # calculate text length and total length
743
+ text_len = text_mask.sum(dim=1) # (bs, )
744
+ total_len = img_seq_len + text_len # (bs, )
745
+
746
+ # set attention mask
747
+ for i in range(bs):
748
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
749
+
750
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
751
+ # --------------------- Pass through DiT blocks ------------------------
752
+ for block_idx, block in enumerate(self.double_blocks):
753
+ double_block_args = [
754
+ img,
755
+ txt,
756
+ vec,
757
+ attn_mask,
758
+ cu_seqlens_q,
759
+ cu_seqlens_kv,
760
+ max_seqlen_q,
761
+ max_seqlen_kv,
762
+ freqs_cis,
763
+ ]
764
+
765
+ if self.blocks_to_swap:
766
+ self.offloader_double.wait_for_block(block_idx)
767
+
768
+ img, txt = block(*double_block_args)
769
+
770
+ if self.blocks_to_swap:
771
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
772
+
773
+ # Merge txt and img to pass through single stream blocks.
774
+ x = torch.cat((img, txt), 1)
775
+ if self.blocks_to_swap:
776
+ # delete img, txt to reduce memory usage
777
+ del img, txt
778
+ clean_memory_on_device(x.device)
779
+
780
+ if len(self.single_blocks) > 0:
781
+ for block_idx, block in enumerate(self.single_blocks):
782
+ single_block_args = [
783
+ x,
784
+ vec,
785
+ txt_seq_len,
786
+ attn_mask,
787
+ cu_seqlens_q,
788
+ cu_seqlens_kv,
789
+ max_seqlen_q,
790
+ max_seqlen_kv,
791
+ (freqs_cos, freqs_sin),
792
+ ]
793
+ if self.blocks_to_swap:
794
+ self.offloader_single.wait_for_block(block_idx)
795
+
796
+ x = block(*single_block_args)
797
+
798
+ if self.blocks_to_swap:
799
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
800
+
801
+ img = x[:, :img_seq_len, ...]
802
+ x = None
803
+
804
+ # ---------------------------- Final layer ------------------------------
805
+ img = self.final_layer(img, vec) # (N, T, pt * ph * pw * out_channels)
806
+
807
+ img = self.unpatchify(img, tt, th, tw)
808
+ if return_dict:
809
+ out["x"] = img
810
+ return out
811
+ return img
812
+
813
+ def unpatchify(self, x, t, h, w):
814
+ """
815
+ x: (N, T, pt * ph * pw * C)
816
+ imgs: (N, C, T * pt, H * ph, W * pw)
817
+ """
818
+ c = self.unpatchify_channels
819
+ pt, ph, pw = self.patch_size
820
+ assert t * h * w == x.shape[1]
821
+
822
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
823
+ x = torch.einsum("nthwcopq->nctohpwq", x)
824
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
825
+
826
+ return imgs
827
+
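+ # Shape sketch for unpatchify (illustrative only): with patch_size = [1, 2, 2] and
+ # hidden tokens of shape (N, t*h*w, 1*2*2*C), the reshape gives (N, t, h, w, C, 1, 2, 2)
+ # and the einsum/reshape produce latents of shape (N, C, t*1, h*2, w*2).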
828
+ def params_count(self):
829
+ counts = {
830
+ "double": sum(
831
+ [
832
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
833
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
834
+ + sum(p.numel() for p in block.img_mlp.parameters())
835
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
836
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
837
+ + sum(p.numel() for p in block.txt_mlp.parameters())
838
+ for block in self.double_blocks
839
+ ]
840
+ ),
841
+ "single": sum(
842
+ [
843
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
844
+ for block in self.single_blocks
845
+ ]
846
+ ),
847
+ "total": sum(p.numel() for p in self.parameters()),
848
+ }
849
+ counts["attn+mlp"] = counts["double"] + counts["single"]
850
+ return counts
851
+
852
+
853
+ #################################################################################
854
+ # HunyuanVideo Configs #
855
+ #################################################################################
856
+
857
+ HUNYUAN_VIDEO_CONFIG = {
858
+ "HYVideo-T/2": {
859
+ "mm_double_blocks_depth": 20,
860
+ "mm_single_blocks_depth": 40,
861
+ "rope_dim_list": [16, 56, 56],
862
+ "hidden_size": 3072,
863
+ "heads_num": 24,
864
+ "mlp_width_ratio": 4,
865
+ },
866
+ "HYVideo-T/2-cfgdistill": {
867
+ "mm_double_blocks_depth": 20,
868
+ "mm_single_blocks_depth": 40,
869
+ "rope_dim_list": [16, 56, 56],
870
+ "hidden_size": 3072,
871
+ "heads_num": 24,
872
+ "mlp_width_ratio": 4,
873
+ "guidance_embed": True,
874
+ },
875
+ }
876
+
877
+
878
+ def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
879
+ """load hunyuan video model
880
+
881
+ NOTE: Only support HYVideo-T/2-cfgdistill now.
882
+
883
+ Args:
884
+ text_states_dim (int): text state dimension
885
+ text_states_dim_2 (int): text state dimension 2
886
+ in_channels (int): input channels number
887
+ out_channels (int): output channels number
888
+ factor_kwargs (dict): factor kwargs
889
+
890
+ Returns:
891
+ model (nn.Module): The hunyuan video model
892
+ """
893
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
894
+ model = HYVideoDiffusionTransformer(
895
+ text_states_dim=text_states_dim,
896
+ text_states_dim_2=text_states_dim_2,
897
+ in_channels=in_channels,
898
+ out_channels=out_channels,
899
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
900
+ **factor_kwargs,
901
+ )
902
+ return model
903
+ # else:
904
+ # raise NotImplementedError()
905
+
906
+
907
+ def load_state_dict(model, model_path):
908
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
909
+
910
+ load_key = "module"
911
+ if load_key in state_dict:
912
+ state_dict = state_dict[load_key]
913
+ else:
914
+ raise KeyError(
915
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
916
+ f"are: {list(state_dict.keys())}."
917
+ )
918
+ model.load_state_dict(state_dict, strict=True, assign=True)
919
+ return model
920
+
921
+
922
+ def load_transformer(dit_path, attn_mode, device, dtype) -> HYVideoDiffusionTransformer:
923
+ # =========================== Build main model ===========================
924
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode}
925
+ latent_channels = 16
926
+ in_channels = latent_channels
927
+ out_channels = latent_channels
928
+
929
+ with accelerate.init_empty_weights():
930
+ transformer = load_dit_model(
931
+ text_states_dim=4096,
932
+ text_states_dim_2=768,
933
+ in_channels=in_channels,
934
+ out_channels=out_channels,
935
+ factor_kwargs=factor_kwargs,
936
+ )
937
+
938
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
939
+ # loading safetensors: may be already fp8
940
+ with MemoryEfficientSafeOpen(dit_path) as f:
941
+ state_dict = {}
942
+ for k in f.keys():
943
+ tensor = f.get_tensor(k)
944
+ tensor = tensor.to(device=device, dtype=dtype)
945
+ # TODO support comfy model
946
+ # if k.startswith("model.model."):
947
+ # k = convert_comfy_model_key(k)
948
+ state_dict[k] = tensor
949
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
950
+ else:
951
+ transformer = load_state_dict(transformer, dit_path)
952
+
953
+ return transformer
954
+
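+ # Minimal usage sketch for load_transformer (the checkpoint path is a placeholder,
+ # not part of the original code):
+ #   transformer = load_transformer(
+ #       "ckpts/hunyuan_video_dit.safetensors", attn_mode="torch",
+ #       device=torch.device("cuda"), dtype=torch.bfloat16,
+ #   )
+ #   transformer.enable_gradient_checkpointing()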
955
+
956
+ def get_rotary_pos_embed_by_shape(model, latents_size):
957
+ target_ndim = 3
958
+ ndim = 5 - 2
959
+
960
+ if isinstance(model.patch_size, int):
961
+ assert all(s % model.patch_size == 0 for s in latents_size), (
962
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
963
+ f"but got {latents_size}."
964
+ )
965
+ rope_sizes = [s // model.patch_size for s in latents_size]
966
+ elif isinstance(model.patch_size, list):
967
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
968
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
969
+ f"but got {latents_size}."
970
+ )
971
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
972
+
973
+ if len(rope_sizes) != target_ndim:
974
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
975
+ head_dim = model.hidden_size // model.heads_num
976
+ rope_dim_list = model.rope_dim_list
977
+ if rope_dim_list is None:
978
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
979
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
980
+
981
+ rope_theta = 256
982
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
983
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
984
+ )
985
+ return freqs_cos, freqs_sin
986
+
987
+
988
+ def get_rotary_pos_embed(vae_name, model, video_length, height, width):
989
+ # 884
990
+ if "884" in vae_name:
991
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
992
+ elif "888" in vae_name:
993
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
994
+ else:
995
+ latents_size = [video_length, height // 8, width // 8]
996
+
997
+ return get_rotary_pos_embed_by_shape(model, latents_size)
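+ # Worked example for get_rotary_pos_embed (a sketch, values are only illustrative):
+ # for an "884" VAE with video_length=129, height=544, width=960 the latent grid is
+ # [(129 - 1) // 4 + 1, 544 // 8, 960 // 8] = [33, 68, 120]; with patch_size [1, 2, 2]
+ # the RoPE grid becomes [33, 34, 60], and rope_dim_list [16, 56, 56] sums to
+ # head_dim = 3072 // 24 = 128 as the assert above requires.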
hunyuan_model/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ return self.linear(self.act(x))
29
+
30
+
31
+ def modulate(x, shift=None, scale=None):
32
+ """modulate by shift and scale
33
+
34
+ Args:
35
+ x (torch.Tensor): input tensor.
36
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
+
39
+ Returns:
40
+ torch.Tensor: the output tensor after modulate.
41
+ """
42
+ if scale is None and shift is None:
43
+ return x
44
+ elif shift is None:
45
+ return x * (1 + scale.unsqueeze(1))
46
+ elif scale is None:
47
+ return x + shift.unsqueeze(1)
48
+ else:
49
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
+
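+ # Broadcasting sketch for modulate (illustrative): x is (B, L, D) while shift and scale
+ # come from a chunked modulation vector of shape (B, D); unsqueeze(1) lifts them to
+ # (B, 1, D) so one shift/scale pair is applied to every token in the sequence, e.g.
+ #   x = torch.randn(2, 16, 8); shift = torch.zeros(2, 8); scale = torch.ones(2, 8)
+ #   modulate(x, shift, scale).shape  # torch.Size([2, 16, 8]); here equal to 2 * x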
51
+
52
+ def apply_gate(x, gate=None, tanh=False):
53
+ """AI is creating summary for apply_gate
54
+
55
+ Args:
56
+ x (torch.Tensor): input tensor.
57
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
59
+
60
+ Returns:
61
+ torch.Tensor: the output tensor after apply gate.
62
+ """
63
+ if gate is None:
64
+ return x
65
+ if tanh:
66
+ return x * gate.unsqueeze(1).tanh()
67
+ else:
68
+ return x * gate.unsqueeze(1)
69
+
70
+
71
+ def ckpt_wrapper(module):
72
+ def ckpt_forward(*inputs):
73
+ outputs = module(*inputs)
74
+ return outputs
75
+
76
+ return ckpt_forward
hunyuan_model/norm_layers.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ elementwise_affine=True,
10
+ eps: float = 1e-6,
11
+ device=None,
12
+ dtype=None,
13
+ ):
14
+ """
15
+ Initialize the RMSNorm normalization layer.
16
+
17
+ Args:
18
+ dim (int): The dimension of the input tensor.
19
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
+
21
+ Attributes:
22
+ eps (float): A small value added to the denominator for numerical stability.
23
+ weight (nn.Parameter): Learnable scaling parameter.
24
+
25
+ """
26
+ factory_kwargs = {"device": device, "dtype": dtype}
27
+ super().__init__()
28
+ self.eps = eps
29
+ if elementwise_affine:
30
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
+
32
+ def _norm(self, x):
33
+ """
34
+ Apply the RMSNorm normalization to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+
39
+ Returns:
40
+ torch.Tensor: The normalized tensor.
41
+
42
+ """
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ """
47
+ Forward pass through the RMSNorm layer.
48
+
49
+ Args:
50
+ x (torch.Tensor): The input tensor.
51
+
52
+ Returns:
53
+ torch.Tensor: The output tensor after applying RMSNorm.
54
+
55
+ """
56
+ output = self._norm(x.float()).type_as(x)
57
+ if hasattr(self, "weight"):
58
+ # output = output * self.weight
59
+ # support fp8
60
+ output = output * self.weight.to(output.dtype)
61
+ return output
62
+
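+ # Equivalent formula sketch for RMSNorm (illustrative):
+ #   y = x / sqrt(mean(x**2, dim=-1, keepdim=True) + eps) * weight
+ # e.g. x = torch.tensor([[3.0, 4.0]]) has rms = sqrt((9 + 16) / 2) ~= 3.5355, so
+ # RMSNorm(2)(x) ~= tensor([[0.8485, 1.1314]]) with the default all-ones weight.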
63
+
64
+ def get_norm_layer(norm_layer):
65
+ """
66
+ Get the normalization layer.
67
+
68
+ Args:
69
+ norm_layer (str): The type of normalization layer.
70
+
71
+ Returns:
72
+ norm_layer (nn.Module): The normalization layer.
73
+ """
74
+ if norm_layer == "layer":
75
+ return nn.LayerNorm
76
+ elif norm_layer == "rms":
77
+ return RMSNorm
78
+ else:
79
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
hunyuan_model/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1100 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.utils import BaseOutput
45
+
46
+ from ...constants import PRECISION_TO_TYPE
47
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
+ from ...text_encoder import TextEncoder
49
+ from ...modules import HYVideoDiffusionTransformer
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """"""
54
+
55
+
56
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
57
+ """
58
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
59
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
60
+ """
61
+ std_text = noise_pred_text.std(
62
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
63
+ )
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = (
69
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
+ )
71
+ return noise_cfg
72
+
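+ # Sketch of the rescaling above (illustrative): with guidance_rescale = 0.7 the output is
+ #   0.7 * noise_cfg * (std_text / std_cfg) + 0.3 * noise_cfg
+ # i.e. the CFG result is partially renormalized to the per-sample std of the text branch,
+ # following Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf.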
73
+
74
+ def retrieve_timesteps(
75
+ scheduler,
76
+ num_inference_steps: Optional[int] = None,
77
+ device: Optional[Union[str, torch.device]] = None,
78
+ timesteps: Optional[List[int]] = None,
79
+ sigmas: Optional[List[float]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None and sigmas is not None:
106
+ raise ValueError(
107
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
+ )
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(
111
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
112
+ )
113
+ if not accepts_timesteps:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" timestep schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ elif sigmas is not None:
122
+ accept_sigmas = "sigmas" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
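+ # Minimal usage sketch for retrieve_timesteps (`scheduler` is a placeholder instance):
+ #   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
+ # Passing `timesteps=...` or `sigmas=...` instead overrides the scheduler's own spacing;
+ # passing both raises a ValueError.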
138
+
139
+ @dataclass
140
+ class HunyuanVideoPipelineOutput(BaseOutput):
141
+ videos: Union[torch.Tensor, np.ndarray]
142
+
143
+
144
+ class HunyuanVideoPipeline(DiffusionPipeline):
145
+ r"""
146
+ Pipeline for text-to-video generation using HunyuanVideo.
147
+
148
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
+ text_encoder ([`TextEncoder`]):
155
+ Frozen text-encoder.
156
+ text_encoder_2 ([`TextEncoder`]):
157
+ Frozen text-encoder_2.
158
+ transformer ([`HYVideoDiffusionTransformer`]):
159
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
+ scheduler ([`SchedulerMixin`]):
161
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
162
+ """
163
+
164
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
+ _optional_components = ["text_encoder_2"]
166
+ _exclude_from_cpu_offload = ["transformer"]
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
+
169
+ def __init__(
170
+ self,
171
+ vae: AutoencoderKL,
172
+ text_encoder: TextEncoder,
173
+ transformer: HYVideoDiffusionTransformer,
174
+ scheduler: KarrasDiffusionSchedulers,
175
+ text_encoder_2: Optional[TextEncoder] = None,
176
+ progress_bar_config: Dict[str, Any] = None,
177
+ args=None,
178
+ ):
179
+ super().__init__()
180
+
181
+ # ==========================================================================================
182
+ if progress_bar_config is None:
183
+ progress_bar_config = {}
184
+ if not hasattr(self, "_progress_bar_config"):
185
+ self._progress_bar_config = {}
186
+ self._progress_bar_config.update(progress_bar_config)
187
+
188
+ self.args = args
189
+ # ==========================================================================================
190
+
191
+ if (
192
+ hasattr(scheduler.config, "steps_offset")
193
+ and scheduler.config.steps_offset != 1
194
+ ):
195
+ deprecation_message = (
196
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
199
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
+ " file"
202
+ )
203
+ deprecate(
204
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
205
+ )
206
+ new_config = dict(scheduler.config)
207
+ new_config["steps_offset"] = 1
208
+ scheduler._internal_dict = FrozenDict(new_config)
209
+
210
+ if (
211
+ hasattr(scheduler.config, "clip_sample")
212
+ and scheduler.config.clip_sample is True
213
+ ):
214
+ deprecation_message = (
215
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
216
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220
+ )
221
+ deprecate(
222
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
223
+ )
224
+ new_config = dict(scheduler.config)
225
+ new_config["clip_sample"] = False
226
+ scheduler._internal_dict = FrozenDict(new_config)
227
+
228
+ self.register_modules(
229
+ vae=vae,
230
+ text_encoder=text_encoder,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ text_encoder_2=text_encoder_2,
234
+ )
235
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+
238
+ def encode_prompt(
239
+ self,
240
+ prompt,
241
+ device,
242
+ num_videos_per_prompt,
243
+ do_classifier_free_guidance,
244
+ negative_prompt=None,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
248
+ negative_attention_mask: Optional[torch.Tensor] = None,
249
+ lora_scale: Optional[float] = None,
250
+ clip_skip: Optional[int] = None,
251
+ text_encoder: Optional[TextEncoder] = None,
252
+ data_type: Optional[str] = "image",
253
+ ):
254
+ r"""
255
+ Encodes the prompt into text encoder hidden states.
256
+
257
+ Args:
258
+ prompt (`str` or `List[str]`, *optional*):
259
+ prompt to be encoded
260
+ device: (`torch.device`):
261
+ torch device
262
+ num_videos_per_prompt (`int`):
263
+ number of videos that should be generated per prompt
264
+ do_classifier_free_guidance (`bool`):
265
+ whether to use classifier free guidance or not
266
+ negative_prompt (`str` or `List[str]`, *optional*):
267
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
268
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
269
+ less than `1`).
270
+ prompt_embeds (`torch.Tensor`, *optional*):
271
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
272
+ provided, text embeddings will be generated from `prompt` input argument.
273
+ attention_mask (`torch.Tensor`, *optional*):
274
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
275
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
276
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
277
+ argument.
278
+ negative_attention_mask (`torch.Tensor`, *optional*):
279
+ lora_scale (`float`, *optional*):
280
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
281
+ clip_skip (`int`, *optional*):
282
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
283
+ the output of the pre-final layer will be used for computing the prompt embeddings.
284
+ text_encoder (TextEncoder, *optional*):
285
+ data_type (`str`, *optional*):
286
+ """
287
+ if text_encoder is None:
288
+ text_encoder = self.text_encoder
289
+
290
+ # set lora scale so that monkey patched LoRA
291
+ # function of text encoder can correctly access it
292
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
293
+ self._lora_scale = lora_scale
294
+
295
+ # dynamically adjust the LoRA scale
296
+ if not USE_PEFT_BACKEND:
297
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
298
+ else:
299
+ scale_lora_layers(text_encoder.model, lora_scale)
300
+
301
+ if prompt is not None and isinstance(prompt, str):
302
+ batch_size = 1
303
+ elif prompt is not None and isinstance(prompt, list):
304
+ batch_size = len(prompt)
305
+ else:
306
+ batch_size = prompt_embeds.shape[0]
307
+
308
+ if prompt_embeds is None:
309
+ # textual inversion: process multi-vector tokens if necessary
310
+ if isinstance(self, TextualInversionLoaderMixin):
311
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
312
+
313
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
314
+
315
+ if clip_skip is None:
316
+ prompt_outputs = text_encoder.encode(
317
+ text_inputs, data_type=data_type, device=device
318
+ )
319
+ prompt_embeds = prompt_outputs.hidden_state
320
+ else:
321
+ prompt_outputs = text_encoder.encode(
322
+ text_inputs,
323
+ output_hidden_states=True,
324
+ data_type=data_type,
325
+ device=device,
326
+ )
327
+ # Access the `hidden_states` first, that contains a tuple of
328
+ # all the hidden states from the encoder layers. Then index into
329
+ # the tuple to access the hidden states from the desired layer.
330
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
331
+ # We also need to apply the final LayerNorm here to not mess with the
332
+ # representations. The `last_hidden_states` that we typically use for
333
+ # obtaining the final prompt representations passes through the LayerNorm
334
+ # layer.
335
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
336
+ prompt_embeds
337
+ )
338
+
339
+ attention_mask = prompt_outputs.attention_mask
340
+ if attention_mask is not None:
341
+ attention_mask = attention_mask.to(device)
342
+ bs_embed, seq_len = attention_mask.shape
343
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
344
+ attention_mask = attention_mask.view(
345
+ bs_embed * num_videos_per_prompt, seq_len
346
+ )
347
+
348
+ if text_encoder is not None:
349
+ prompt_embeds_dtype = text_encoder.dtype
350
+ elif self.transformer is not None:
351
+ prompt_embeds_dtype = self.transformer.dtype
352
+ else:
353
+ prompt_embeds_dtype = prompt_embeds.dtype
354
+
355
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
356
+
357
+ if prompt_embeds.ndim == 2:
358
+ bs_embed, _ = prompt_embeds.shape
359
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
360
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
361
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
362
+ else:
363
+ bs_embed, seq_len, _ = prompt_embeds.shape
364
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
365
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
366
+ prompt_embeds = prompt_embeds.view(
367
+ bs_embed * num_videos_per_prompt, seq_len, -1
368
+ )
369
+
370
+ # get unconditional embeddings for classifier free guidance
371
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
372
+ uncond_tokens: List[str]
373
+ if negative_prompt is None:
374
+ uncond_tokens = [""] * batch_size
375
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
376
+ raise TypeError(
377
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
378
+ f" {type(prompt)}."
379
+ )
380
+ elif isinstance(negative_prompt, str):
381
+ uncond_tokens = [negative_prompt]
382
+ elif batch_size != len(negative_prompt):
383
+ raise ValueError(
384
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
385
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
386
+ " the batch size of `prompt`."
387
+ )
388
+ else:
389
+ uncond_tokens = negative_prompt
390
+
391
+ # textual inversion: process multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(
394
+ uncond_tokens, text_encoder.tokenizer
395
+ )
396
+
397
+ # max_length = prompt_embeds.shape[1]
398
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
399
+
400
+ negative_prompt_outputs = text_encoder.encode(
401
+ uncond_input, data_type=data_type, device=device
402
+ )
403
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
404
+
405
+ negative_attention_mask = negative_prompt_outputs.attention_mask
406
+ if negative_attention_mask is not None:
407
+ negative_attention_mask = negative_attention_mask.to(device)
408
+ _, seq_len = negative_attention_mask.shape
409
+ negative_attention_mask = negative_attention_mask.repeat(
410
+ 1, num_videos_per_prompt
411
+ )
412
+ negative_attention_mask = negative_attention_mask.view(
413
+ batch_size * num_videos_per_prompt, seq_len
414
+ )
415
+
416
+ if do_classifier_free_guidance:
417
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418
+ seq_len = negative_prompt_embeds.shape[1]
419
+
420
+ negative_prompt_embeds = negative_prompt_embeds.to(
421
+ dtype=prompt_embeds_dtype, device=device
422
+ )
423
+
424
+ if negative_prompt_embeds.ndim == 2:
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
426
+ 1, num_videos_per_prompt
427
+ )
428
+ negative_prompt_embeds = negative_prompt_embeds.view(
429
+ batch_size * num_videos_per_prompt, -1
430
+ )
431
+ else:
432
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
433
+ 1, num_videos_per_prompt, 1
434
+ )
435
+ negative_prompt_embeds = negative_prompt_embeds.view(
436
+ batch_size * num_videos_per_prompt, seq_len, -1
437
+ )
438
+
439
+ if text_encoder is not None:
440
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
441
+ # Retrieve the original scale by scaling back the LoRA layers
442
+ unscale_lora_layers(text_encoder.model, lora_scale)
443
+
444
+ return (
445
+ prompt_embeds,
446
+ negative_prompt_embeds,
447
+ attention_mask,
448
+ negative_attention_mask,
449
+ )
450
+
451
+ def decode_latents(self, latents, enable_tiling=True):
452
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+ if enable_tiling:
457
+ self.vae.enable_tiling()
458
+ image = self.vae.decode(latents, return_dict=False)[0]
459
+ else:
460
+ image = self.vae.decode(latents, return_dict=False)[0]
461
+ image = (image / 2 + 0.5).clamp(0, 1)
462
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
+ if image.ndim == 4:
464
+ image = image.cpu().permute(0, 2, 3, 1).float()
465
+ else:
466
+ image = image.cpu().float()
467
+ return image
468
+
469
+ def prepare_extra_func_kwargs(self, func, kwargs):
470
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
+ # and should be between [0, 1]
474
+ extra_step_kwargs = {}
475
+
476
+ for k, v in kwargs.items():
477
+ accepts = k in set(inspect.signature(func).parameters.keys())
478
+ if accepts:
479
+ extra_step_kwargs[k] = v
480
+ return extra_step_kwargs
481
+
482
+ def check_inputs(
483
+ self,
484
+ prompt,
485
+ height,
486
+ width,
487
+ video_length,
488
+ callback_steps,
489
+ negative_prompt=None,
490
+ prompt_embeds=None,
491
+ negative_prompt_embeds=None,
492
+ callback_on_step_end_tensor_inputs=None,
493
+ vae_ver="88-4c-sd",
494
+ ):
495
+ if height % 8 != 0 or width % 8 != 0:
496
+ raise ValueError(
497
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
498
+ )
499
+
500
+ if video_length is not None:
501
+ if "884" in vae_ver:
502
+ if video_length != 1 and (video_length - 1) % 4 != 0:
503
+ raise ValueError(
504
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
505
+ )
506
+ elif "888" in vae_ver:
507
+ if video_length != 1 and (video_length - 1) % 8 != 0:
508
+ raise ValueError(
509
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
510
+ )
511
+
512
+ if callback_steps is not None and (
513
+ not isinstance(callback_steps, int) or callback_steps <= 0
514
+ ):
515
+ raise ValueError(
516
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
+ f" {type(callback_steps)}."
518
+ )
519
+ if callback_on_step_end_tensor_inputs is not None and not all(
520
+ k in self._callback_tensor_inputs
521
+ for k in callback_on_step_end_tensor_inputs
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
525
+ )
526
+
527
+ if prompt is not None and prompt_embeds is not None:
528
+ raise ValueError(
529
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
530
+ " only forward one of the two."
531
+ )
532
+ elif prompt is None and prompt_embeds is None:
533
+ raise ValueError(
534
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
535
+ )
536
+ elif prompt is not None and (
537
+ not isinstance(prompt, str) and not isinstance(prompt, list)
538
+ ):
539
+ raise ValueError(
540
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
541
+ )
542
+
543
+ if negative_prompt is not None and negative_prompt_embeds is not None:
544
+ raise ValueError(
545
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
546
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
547
+ )
548
+
549
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
550
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
551
+ raise ValueError(
552
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
553
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
554
+ f" {negative_prompt_embeds.shape}."
555
+ )
556
+
557
+
558
+ def prepare_latents(
559
+ self,
560
+ batch_size,
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ video_length,
565
+ dtype,
566
+ device,
567
+ generator,
568
+ latents=None,
569
+ ):
570
+ shape = (
571
+ batch_size,
572
+ num_channels_latents,
573
+ video_length,
574
+ int(height) // self.vae_scale_factor,
575
+ int(width) // self.vae_scale_factor,
576
+ )
577
+ if isinstance(generator, list) and len(generator) != batch_size:
578
+ raise ValueError(
579
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
+ )
582
+
583
+ if latents is None:
584
+ latents = randn_tensor(
585
+ shape, generator=generator, device=device, dtype=dtype
586
+ )
587
+ else:
588
+ latents = latents.to(device)
589
+
590
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
+ if hasattr(self.scheduler, "init_noise_sigma"):
592
+ # scale the initial noise by the standard deviation required by the scheduler
593
+ latents = latents * self.scheduler.init_noise_sigma
594
+ return latents
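+ # Example shape: with the usual vae_scale_factor of 8, a 720x1280 request with
+ # 33 latent frames and 16 latent channels yields latents of shape
+ # (batch_size, 16, 33, 90, 160).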
595
+
596
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
+ def get_guidance_scale_embedding(
598
+ self,
599
+ w: torch.Tensor,
600
+ embedding_dim: int = 512,
601
+ dtype: torch.dtype = torch.float32,
602
+ ) -> torch.Tensor:
603
+ """
604
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
+
606
+ Args:
607
+ w (`torch.Tensor`):
608
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
+ embedding_dim (`int`, *optional*, defaults to 512):
610
+ Dimension of the embeddings to generate.
611
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
+ Data type of the generated embeddings.
613
+
614
+ Returns:
615
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
+ """
617
+ assert len(w.shape) == 1
618
+ w = w * 1000.0
619
+
620
+ half_dim = embedding_dim // 2
621
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
+ emb = w.to(dtype)[:, None] * emb[None, :]
624
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
+ if embedding_dim % 2 == 1: # zero pad
626
+ emb = torch.nn.functional.pad(emb, (0, 1))
627
+ assert emb.shape == (w.shape[0], embedding_dim)
628
+ return emb
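+ # Example: get_guidance_scale_embedding(torch.tensor([7.5]), embedding_dim=512)
+ # returns a (1, 512) tensor whose first half holds the sin features and the
+ # second half the cos features of the scaled guidance value.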
629
+
630
+ @property
631
+ def guidance_scale(self):
632
+ return self._guidance_scale
633
+
634
+ @property
635
+ def guidance_rescale(self):
636
+ return self._guidance_rescale
637
+
638
+ @property
639
+ def clip_skip(self):
640
+ return self._clip_skip
641
+
642
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
643
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
+ # corresponds to doing no classifier free guidance.
645
+ @property
646
+ def do_classifier_free_guidance(self):
647
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
648
+ return self._guidance_scale > 1
649
+
650
+ @property
651
+ def cross_attention_kwargs(self):
652
+ return self._cross_attention_kwargs
653
+
654
+ @property
655
+ def num_timesteps(self):
656
+ return self._num_timesteps
657
+
658
+ @property
659
+ def interrupt(self):
660
+ return self._interrupt
661
+
662
+ @torch.no_grad()
663
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
664
+ def __call__(
665
+ self,
666
+ prompt: Union[str, List[str]],
667
+ height: int,
668
+ width: int,
669
+ video_length: int,
670
+ data_type: str = "video",
671
+ num_inference_steps: int = 50,
672
+ timesteps: List[int] = None,
673
+ sigmas: List[float] = None,
674
+ guidance_scale: float = 7.5,
675
+ negative_prompt: Optional[Union[str, List[str]]] = None,
676
+ num_videos_per_prompt: Optional[int] = 1,
677
+ eta: float = 0.0,
678
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
+ latents: Optional[torch.Tensor] = None,
680
+ prompt_embeds: Optional[torch.Tensor] = None,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
683
+ negative_attention_mask: Optional[torch.Tensor] = None,
684
+ output_type: Optional[str] = "pil",
685
+ return_dict: bool = True,
686
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
+ guidance_rescale: float = 0.0,
688
+ clip_skip: Optional[int] = None,
689
+ callback_on_step_end: Optional[
690
+ Union[
691
+ Callable[[int, int, Dict], None],
692
+ PipelineCallback,
693
+ MultiPipelineCallbacks,
694
+ ]
695
+ ] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
+ vae_ver: str = "88-4c-sd",
699
+ enable_tiling: bool = False,
700
+ n_tokens: Optional[int] = None,
701
+ embedded_guidance_scale: Optional[float] = None,
702
+ **kwargs,
703
+ ):
704
+ r"""
705
+ The call function to the pipeline for generation.
706
+
707
+ Args:
708
+ prompt (`str` or `List[str]`):
709
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
710
+ height (`int`):
711
+ The height in pixels of the generated image.
712
+ width (`int`):
713
+ The width in pixels of the generated image.
714
+ video_length (`int`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, *optional*, defaults to 50):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ timesteps (`List[int]`, *optional*):
720
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
+ passed will be used. Must be in descending order.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
+ will be used.
727
+ guidance_scale (`float`, *optional*, defaults to 7.5):
728
+ A higher guidance scale value encourages the model to generate images closely linked to the text
729
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
+ negative_prompt (`str` or `List[str]`, *optional*):
731
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of images to generate per prompt.
735
+ eta (`float`, *optional*, defaults to 0.0):
736
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
+ generation deterministic.
741
+ latents (`torch.Tensor`, *optional*):
742
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
+ tensor is generated by sampling using the supplied random `generator`.
745
+ prompt_embeds (`torch.Tensor`, *optional*):
746
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
+ provided, text embeddings are generated from the `prompt` input argument.
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
+
752
+ output_type (`str`, *optional*, defaults to `"pil"`):
753
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
+ return_dict (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
+ plain tuple.
757
+ cross_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
761
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
+ using zero terminal SNR.
764
+ clip_skip (`int`, *optional*):
765
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
+ the output of the pre-final layer will be used for computing the prompt embeddings.
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
+ otherwise the decoded images are returned directly.
785
+ """
786
+ callback = kwargs.pop("callback", None)
787
+ callback_steps = kwargs.pop("callback_steps", None)
788
+
789
+ if callback is not None:
790
+ deprecate(
791
+ "callback",
792
+ "1.0.0",
793
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
+ )
795
+ if callback_steps is not None:
796
+ deprecate(
797
+ "callback_steps",
798
+ "1.0.0",
799
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
+ )
801
+
802
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
+
805
+ # 0. Default height and width to unet
806
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
+ # to deal with lora scaling and other possible forward hooks
809
+
810
+ # 1. Check inputs. Raise error if not correct
811
+ self.check_inputs(
812
+ prompt,
813
+ height,
814
+ width,
815
+ video_length,
816
+ callback_steps,
817
+ negative_prompt,
818
+ prompt_embeds,
819
+ negative_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs,
821
+ vae_ver=vae_ver,
822
+ )
823
+
824
+ self._guidance_scale = guidance_scale
825
+ self._guidance_rescale = guidance_rescale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+ self._interrupt = False
829
+
830
+ # 2. Define call parameters
831
+ if prompt is not None and isinstance(prompt, str):
832
+ batch_size = 1
833
+ elif prompt is not None and isinstance(prompt, list):
834
+ batch_size = len(prompt)
835
+ else:
836
+ batch_size = prompt_embeds.shape[0]
837
+
838
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
+
840
+ # 3. Encode input prompt
841
+ lora_scale = (
842
+ self.cross_attention_kwargs.get("scale", None)
843
+ if self.cross_attention_kwargs is not None
844
+ else None
845
+ )
846
+
847
+ (
848
+ prompt_embeds,
849
+ negative_prompt_embeds,
850
+ prompt_mask,
851
+ negative_prompt_mask,
852
+ ) = self.encode_prompt(
853
+ prompt,
854
+ device,
855
+ num_videos_per_prompt,
856
+ self.do_classifier_free_guidance,
857
+ negative_prompt,
858
+ prompt_embeds=prompt_embeds,
859
+ attention_mask=attention_mask,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ negative_attention_mask=negative_attention_mask,
862
+ lora_scale=lora_scale,
863
+ clip_skip=self.clip_skip,
864
+ data_type=data_type,
865
+ )
866
+ if self.text_encoder_2 is not None:
867
+ (
868
+ prompt_embeds_2,
869
+ negative_prompt_embeds_2,
870
+ prompt_mask_2,
871
+ negative_prompt_mask_2,
872
+ ) = self.encode_prompt(
873
+ prompt,
874
+ device,
875
+ num_videos_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ negative_prompt,
878
+ prompt_embeds=None,
879
+ attention_mask=None,
880
+ negative_prompt_embeds=None,
881
+ negative_attention_mask=None,
882
+ lora_scale=lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ text_encoder=self.text_encoder_2,
885
+ data_type=data_type,
886
+ )
887
+ else:
888
+ prompt_embeds_2 = None
889
+ negative_prompt_embeds_2 = None
890
+ prompt_mask_2 = None
891
+ negative_prompt_mask_2 = None
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ if self.do_classifier_free_guidance:
897
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
+ if prompt_mask is not None:
899
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
+ if prompt_embeds_2 is not None:
901
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
+ if prompt_mask_2 is not None:
903
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
+
905
+
906
+ # 4. Prepare timesteps
907
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
+ )
910
+ timesteps, num_inference_steps = retrieve_timesteps(
911
+ self.scheduler,
912
+ num_inference_steps,
913
+ device,
914
+ timesteps,
915
+ sigmas,
916
+ **extra_set_timesteps_kwargs,
917
+ )
918
+
919
+ if "884" in vae_ver:
920
+ video_length = (video_length - 1) // 4 + 1
921
+ elif "888" in vae_ver:
922
+ video_length = (video_length - 1) // 8 + 1
923
+ else:
924
+ video_length = video_length
925
+
926
+ # 5. Prepare latent variables
927
+ num_channels_latents = self.transformer.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_videos_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ video_length,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
942
+ self.scheduler.step,
943
+ {"generator": generator, "eta": eta},
944
+ )
945
+
946
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
+ autocast_enabled = (
948
+ target_dtype != torch.float32
949
+ ) and not self.args.disable_autocast
950
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
+ vae_autocast_enabled = (
952
+ vae_dtype != torch.float32
953
+ ) and not self.args.disable_autocast
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ self._num_timesteps = len(timesteps)
958
+
959
+ # if is_progress_bar:
960
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
961
+ for i, t in enumerate(timesteps):
962
+ if self.interrupt:
963
+ continue
964
+
965
+ # expand the latents if we are doing classifier free guidance
966
+ latent_model_input = (
967
+ torch.cat([latents] * 2)
968
+ if self.do_classifier_free_guidance
969
+ else latents
970
+ )
971
+ latent_model_input = self.scheduler.scale_model_input(
972
+ latent_model_input, t
973
+ )
974
+
975
+ t_expand = t.repeat(latent_model_input.shape[0])
976
+ guidance_expand = (
977
+ torch.tensor(
978
+ [embedded_guidance_scale] * latent_model_input.shape[0],
979
+ dtype=torch.float32,
980
+ device=device,
981
+ ).to(target_dtype)
982
+ * 1000.0
983
+ if embedded_guidance_scale is not None
984
+ else None
985
+ )
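+ # The embedded (distilled) guidance value is scaled by 1000.0 so that it matches
+ # the scale of the timestep values consumed by the transformer's guidance embedder.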
986
+
987
+ # predict the noise residual
988
+ with torch.autocast(
989
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
990
+ ):
991
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
992
+ latent_model_input, # [2, 16, 33, 24, 42]
993
+ t_expand, # [2]
994
+ text_states=prompt_embeds, # [2, 256, 4096]
995
+ text_mask=prompt_mask, # [2, 256]
996
+ text_states_2=prompt_embeds_2, # [2, 768]
997
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
998
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
999
+ guidance=guidance_expand,
1000
+ return_dict=True,
1001
+ )[
1002
+ "x"
1003
+ ]
1004
+
1005
+ # perform guidance
1006
+ if self.do_classifier_free_guidance:
1007
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1008
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1009
+ noise_pred_text - noise_pred_uncond
1010
+ )
1011
+
1012
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1013
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1014
+ noise_pred = rescale_noise_cfg(
1015
+ noise_pred,
1016
+ noise_pred_text,
1017
+ guidance_rescale=self.guidance_rescale,
1018
+ )
1019
+
1020
+ # compute the previous noisy sample x_t -> x_t-1
1021
+ latents = self.scheduler.step(
1022
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1023
+ )[0]
1024
+
1025
+ if callback_on_step_end is not None:
1026
+ callback_kwargs = {}
1027
+ for k in callback_on_step_end_tensor_inputs:
1028
+ callback_kwargs[k] = locals()[k]
1029
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1030
+
1031
+ latents = callback_outputs.pop("latents", latents)
1032
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1033
+ negative_prompt_embeds = callback_outputs.pop(
1034
+ "negative_prompt_embeds", negative_prompt_embeds
1035
+ )
1036
+
1037
+ # call the callback, if provided
1038
+ if i == len(timesteps) - 1 or (
1039
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1040
+ ):
1041
+ if progress_bar is not None:
1042
+ progress_bar.update()
1043
+ if callback is not None and i % callback_steps == 0:
1044
+ step_idx = i // getattr(self.scheduler, "order", 1)
1045
+ callback(step_idx, t, latents)
1046
+
1047
+ if not output_type == "latent":
1048
+ expand_temporal_dim = False
1049
+ if len(latents.shape) == 4:
1050
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1051
+ latents = latents.unsqueeze(2)
1052
+ expand_temporal_dim = True
1053
+ elif len(latents.shape) == 5:
1054
+ pass
1055
+ else:
1056
+ raise ValueError(
1057
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1058
+ )
1059
+
1060
+ if (
1061
+ hasattr(self.vae.config, "shift_factor")
1062
+ and self.vae.config.shift_factor
1063
+ ):
1064
+ latents = (
1065
+ latents / self.vae.config.scaling_factor
1066
+ + self.vae.config.shift_factor
1067
+ )
1068
+ else:
1069
+ latents = latents / self.vae.config.scaling_factor
1070
+
1071
+ with torch.autocast(
1072
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1073
+ ):
1074
+ if enable_tiling:
1075
+ self.vae.enable_tiling()
1076
+ image = self.vae.decode(
1077
+ latents, return_dict=False, generator=generator
1078
+ )[0]
1079
+ else:
1080
+ image = self.vae.decode(
1081
+ latents, return_dict=False, generator=generator
1082
+ )[0]
1083
+
1084
+ if expand_temporal_dim or image.shape[2] == 1:
1085
+ image = image.squeeze(2)
1086
+
1087
+ else:
1088
+ image = latents
1089
+
1090
+ image = (image / 2 + 0.5).clamp(0, 1)
1091
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1092
+ image = image.cpu().float()
1093
+
1094
+ # Offload all models
1095
+ self.maybe_free_model_hooks()
1096
+
1097
+ if not return_dict:
1098
+ return image
1099
+
1100
+ return HunyuanVideoPipelineOutput(videos=image)
hunyuan_model/posemb_layers.py ADDED
@@ -0,0 +1,310 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (torch.Tensor): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def reshape_for_broadcast(
66
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
+ x: torch.Tensor,
68
+ head_first=False,
69
+ ):
70
+ """
71
+ Reshape frequency tensor for broadcasting it with another tensor.
72
+
73
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
75
+
76
+ Notes:
77
+ When using FlashMHAModified, head_first should be False.
78
+ When using Attention, head_first should be True.
79
+
80
+ Args:
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ head_first (bool): head dimension first (except batch dim) or not.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+
88
+ Raises:
89
+ AssertionError: If the frequency tensor doesn't match the expected shape.
90
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
+ """
92
+ ndim = x.ndim
93
+ assert 0 <= 1 < ndim
94
+
95
+ if isinstance(freqs_cis, tuple):
96
+ # freqs_cis: (cos, sin) in real space
97
+ if head_first:
98
+ assert freqs_cis[0].shape == (
99
+ x.shape[-2],
100
+ x.shape[-1],
101
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
+ shape = [
103
+ d if i == ndim - 2 or i == ndim - 1 else 1
104
+ for i, d in enumerate(x.shape)
105
+ ]
106
+ else:
107
+ assert freqs_cis[0].shape == (
108
+ x.shape[1],
109
+ x.shape[-1],
110
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
+ else:
114
+ # freqs_cis: values in complex space
115
+ if head_first:
116
+ assert freqs_cis.shape == (
117
+ x.shape[-2],
118
+ x.shape[-1],
119
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
+ shape = [
121
+ d if i == ndim - 2 or i == ndim - 1 else 1
122
+ for i, d in enumerate(x.shape)
123
+ ]
124
+ else:
125
+ assert freqs_cis.shape == (
126
+ x.shape[1],
127
+ x.shape[-1],
128
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def rotate_half(x):
134
+ x_real, x_imag = (
135
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
+ ) # [B, S, H, D//2]
137
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
+ head_first: bool = False,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Apply rotary embeddings to input tensors using the given frequency tensor.
148
+
149
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
+ returned as real tensors.
153
+
154
+ Args:
155
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
+ head_first (bool): head dimension first (except batch dim) or not.
159
+
160
+ Returns:
161
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
+
163
+ """
164
+ xk_out = None
165
+ if isinstance(freqs_cis, tuple):
166
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
168
+ # real * cos - imag * sin
169
+ # imag * cos + real * sin
170
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
+ else:
173
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
+ xq_ = torch.view_as_complex(
175
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
176
+ ) # [B, S, H, D//2]
177
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
+ xq.device
179
+ ) # [S, D//2] --> [1, S, 1, D//2]
180
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
+ xk_ = torch.view_as_complex(
184
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
185
+ ) # [B, S, H, D//2]
186
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
+
188
+ return xq_out, xk_out
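+ # Usage sketch (shapes assumed, head_first=False): xq and xk are [B, S, H, D] and
+ # freqs_cis is either a complex tensor [S, D/2] or a (cos, sin) pair of [S, D]
+ # tensors, e.g. the output of get_nd_rotary_pos_embed(..., use_real=True) below.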
189
+
190
+
191
+ def get_nd_rotary_pos_embed(
192
+ rope_dim_list,
193
+ start,
194
+ *args,
195
+ theta=10000.0,
196
+ use_real=False,
197
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
198
+ interpolation_factor: Union[float, List[float]] = 1.0,
199
+ ):
200
+ """
201
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
+
203
+ Args:
204
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
205
+ sum(rope_dim_list) should equal to head_dim of attention layer.
206
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
+ *args: See above.
209
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
212
+ part and an imaginary part separately.
213
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
+
215
+ Returns:
216
+ pos_embed (torch.Tensor): [HW, D/2]
217
+ """
218
+
219
+ grid = get_meshgrid_nd(
220
+ start, *args, dim=len(rope_dim_list)
221
+ ) # [3, W, H, D] / [2, W, H]
222
+
223
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
+ assert len(theta_rescale_factor) == len(
228
+ rope_dim_list
229
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
+
231
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
+ assert len(interpolation_factor) == len(
236
+ rope_dim_list
237
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
+
239
+ # use 1/ndim of dimensions to encode grid_axis
240
+ embs = []
241
+ for i in range(len(rope_dim_list)):
242
+ emb = get_1d_rotary_pos_embed(
243
+ rope_dim_list[i],
244
+ grid[i].reshape(-1),
245
+ theta,
246
+ use_real=use_real,
247
+ theta_rescale_factor=theta_rescale_factor[i],
248
+ interpolation_factor=interpolation_factor[i],
249
+ ) # 2 x [WHD, rope_dim_list[i]]
250
+ embs.append(emb)
251
+
252
+ if use_real:
253
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
+ return cos, sin
256
+ else:
257
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
+ return emb
259
+
260
+
261
+ def get_1d_rotary_pos_embed(
262
+ dim: int,
263
+ pos: Union[torch.FloatTensor, int],
264
+ theta: float = 10000.0,
265
+ use_real: bool = False,
266
+ theta_rescale_factor: float = 1.0,
267
+ interpolation_factor: float = 1.0,
268
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
+ """
270
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
+
273
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
275
+ The returned tensor contains complex values in complex64 data type.
276
+
277
+ Args:
278
+ dim (int): Dimension of the frequency tensor.
279
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
+ use_real (bool, optional): If True, return real part and imaginary part separately.
282
+ Otherwise, return complex numbers.
283
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
+
285
+ Returns:
286
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
+ """
289
+ if isinstance(pos, int):
290
+ pos = torch.arange(pos).float()
291
+
292
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
+ # has some connection to NTK literature
294
+ if theta_rescale_factor != 1.0:
295
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
296
+
297
+ freqs = 1.0 / (
298
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
+ ) # [D/2]
300
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
+ if use_real:
303
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
+ return freqs_cos, freqs_sin
306
+ else:
307
+ freqs_cis = torch.polar(
308
+ torch.ones_like(freqs), freqs
309
+ ) # complex64 # [S, D/2]
310
+ return freqs_cis
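+ # Note: with use_real=True the cos/sin tables are repeat_interleave'd by 2 so that
+ # consecutive feature pairs (2i, 2i+1) share one rotation angle, matching the
+ # (real, imag) pairing used by rotate_half() above.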
hunyuan_model/text_encoder.py ADDED
@@ -0,0 +1,438 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
8
+ from transformers.utils import ModelOutput
9
+ from transformers.models.llama import LlamaModel
10
+
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+
17
+ # When using decoder-only models, we must provide a prompt template to instruct the text encoder
18
+ # on how to generate the text.
19
+ # --------------------------------------------------------------------
20
+ PROMPT_TEMPLATE_ENCODE = (
21
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
22
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
23
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
24
+ )
25
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
26
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
27
+ "1. The main content and theme of the video."
28
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
29
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
30
+ "4. background environment, light, style and atmosphere."
31
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
32
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
33
+ )
34
+
35
+ NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
36
+
37
+ PROMPT_TEMPLATE = {
38
+ "dit-llm-encode": {
39
+ "template": PROMPT_TEMPLATE_ENCODE,
40
+ "crop_start": 36,
41
+ },
42
+ "dit-llm-encode-video": {
43
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
44
+ "crop_start": 95,
45
+ },
46
+ }
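+ # `crop_start` is the number of leading instruction-template tokens that are later
+ # stripped from the encoder output in TextEncoder.encode(), so only the hidden
+ # states of the user prompt tokens are kept.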
47
+
48
+
49
+ def use_default(value, default):
50
+ return value if value is not None else default
51
+
52
+
53
+ def load_text_encoder(
54
+ text_encoder_type: str,
55
+ text_encoder_path: str,
56
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
57
+ ):
58
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
59
+
60
+ # reduce peak memory usage by specifying the dtype of the model
61
+ dtype = text_encoder_dtype
62
+ if text_encoder_type == "clipL":
63
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
64
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
65
+ elif text_encoder_type == "llm":
66
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
67
+ text_encoder.final_layer_norm = text_encoder.norm
68
+ else:
69
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
70
+ # from_pretrained will ensure that the model is in eval mode.
71
+
72
+ if dtype is not None:
73
+ text_encoder = text_encoder.to(dtype=dtype)
74
+
75
+ text_encoder.requires_grad_(False)
76
+
77
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
78
+ return text_encoder, text_encoder_path
79
+
80
+
81
+ def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
82
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
83
+
84
+ if tokenizer_type == "clipL":
85
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
86
+ elif tokenizer_type == "llm":
87
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
88
+ else:
89
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
90
+
91
+ return tokenizer, tokenizer_path
92
+
93
+
94
+ @dataclass
95
+ class TextEncoderModelOutput(ModelOutput):
96
+ """
97
+ Base class for model's outputs that also contains a pooling of the last hidden states.
98
+
99
+ Args:
100
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
101
+ Sequence of hidden-states at the output of the last layer of the model.
102
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
103
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
104
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
105
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
106
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
107
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
108
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
109
+ List of decoded texts.
110
+ """
111
+
112
+ hidden_state: torch.FloatTensor = None
113
+ attention_mask: Optional[torch.LongTensor] = None
114
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
115
+ text_outputs: Optional[list] = None
116
+
117
+
118
+ class TextEncoder(nn.Module):
119
+ def __init__(
120
+ self,
121
+ text_encoder_type: str,
122
+ max_length: int,
123
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
124
+ text_encoder_path: Optional[str] = None,
125
+ tokenizer_type: Optional[str] = None,
126
+ tokenizer_path: Optional[str] = None,
127
+ output_key: Optional[str] = None,
128
+ use_attention_mask: bool = True,
129
+ input_max_length: Optional[int] = None,
130
+ prompt_template: Optional[dict] = None,
131
+ prompt_template_video: Optional[dict] = None,
132
+ hidden_state_skip_layer: Optional[int] = None,
133
+ apply_final_norm: bool = False,
134
+ reproduce: bool = False,
135
+ ):
136
+ super().__init__()
137
+ self.text_encoder_type = text_encoder_type
138
+ self.max_length = max_length
139
+ # self.precision = text_encoder_precision
140
+ self.model_path = text_encoder_path
141
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
142
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
143
+ self.use_attention_mask = use_attention_mask
144
+ if prompt_template_video is not None:
145
+ assert use_attention_mask is True, "Attention mask must be True when training on videos."
146
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
147
+ self.prompt_template = prompt_template
148
+ self.prompt_template_video = prompt_template_video
149
+ self.hidden_state_skip_layer = hidden_state_skip_layer
150
+ self.apply_final_norm = apply_final_norm
151
+ self.reproduce = reproduce
152
+
153
+ self.use_template = self.prompt_template is not None
154
+ if self.use_template:
155
+ assert (
156
+ isinstance(self.prompt_template, dict) and "template" in self.prompt_template
157
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
158
+ assert "{}" in str(self.prompt_template["template"]), (
159
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
160
+ f"got {self.prompt_template['template']}"
161
+ )
162
+
163
+ self.use_video_template = self.prompt_template_video is not None
164
+ if self.use_video_template:
165
+ if self.prompt_template_video is not None:
166
+ assert (
167
+ isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
168
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
169
+ assert "{}" in str(self.prompt_template_video["template"]), (
170
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
171
+ f"got {self.prompt_template_video['template']}"
172
+ )
173
+
174
+ if "t5" in text_encoder_type:
175
+ self.output_key = output_key or "last_hidden_state"
176
+ elif "clip" in text_encoder_type:
177
+ self.output_key = output_key or "pooler_output"
178
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
179
+ self.output_key = output_key or "last_hidden_state"
180
+ else:
181
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
182
+
183
+ self.model, self.model_path = load_text_encoder(
184
+ text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
185
+ )
186
+ self.dtype = self.model.dtype
187
+
188
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
189
+ tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
190
+ )
191
+
192
+ def __repr__(self):
193
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
194
+
195
+ @property
196
+ def device(self):
197
+ return self.model.device
198
+
199
+ @staticmethod
200
+ def apply_text_to_template(text, template, prevent_empty_text=True):
201
+ """
202
+ Apply text to template.
203
+
204
+ Args:
205
+ text (str): Input text.
206
+ template (str or list): Template string or list of chat conversation.
207
+ prevent_empty_text (bool): If True, we will prevent the user text from being empty
208
+ by adding a space. Defaults to True.
209
+ """
210
+ if isinstance(template, str):
211
+ # Will send string to tokenizer. Used for llm
212
+ return template.format(text)
213
+ else:
214
+ raise TypeError(f"Unsupported template type: {type(template)}")
215
+
216
+ def text2tokens(self, text, data_type="image"):
217
+ """
218
+ Tokenize the input text.
219
+
220
+ Args:
221
+ text (str or list): Input text.
222
+ """
223
+ tokenize_input_type = "str"
224
+ if self.use_template:
225
+ if data_type == "image":
226
+ prompt_template = self.prompt_template["template"]
227
+ elif data_type == "video":
228
+ prompt_template = self.prompt_template_video["template"]
229
+ else:
230
+ raise ValueError(f"Unsupported data type: {data_type}")
231
+ if isinstance(text, (list, tuple)):
232
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
233
+ if isinstance(text[0], list):
234
+ tokenize_input_type = "list"
235
+ elif isinstance(text, str):
236
+ text = self.apply_text_to_template(text, prompt_template)
237
+ if isinstance(text, list):
238
+ tokenize_input_type = "list"
239
+ else:
240
+ raise TypeError(f"Unsupported text type: {type(text)}")
241
+
242
+ kwargs = dict(
243
+ truncation=True,
244
+ max_length=self.max_length,
245
+ padding="max_length",
246
+ return_tensors="pt",
247
+ )
248
+ if tokenize_input_type == "str":
249
+ return self.tokenizer(
250
+ text,
251
+ return_length=False,
252
+ return_overflowing_tokens=False,
253
+ return_attention_mask=True,
254
+ **kwargs,
255
+ )
256
+ elif tokenize_input_type == "list":
257
+ return self.tokenizer.apply_chat_template(
258
+ text,
259
+ add_generation_prompt=True,
260
+ tokenize=True,
261
+ return_dict=True,
262
+ **kwargs,
263
+ )
264
+ else:
265
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
266
+
267
+ def encode(
268
+ self,
269
+ batch_encoding,
270
+ use_attention_mask=None,
271
+ output_hidden_states=False,
272
+ do_sample=None,
273
+ hidden_state_skip_layer=None,
274
+ return_texts=False,
275
+ data_type="image",
276
+ device=None,
277
+ ):
278
+ """
279
+ Args:
280
+ batch_encoding (dict): Batch encoding from tokenizer.
281
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
282
+ Defaults to None.
283
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
284
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
285
+ output_hidden_states will be set True. Defaults to False.
286
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
287
+ When self.reproduce is False, do_sample is set to True by default.
+ hidden_state_skip_layer (int): Number of final hidden layers to skip; 0 means the last layer.
+ If None, self.hidden_state_skip_layer is used; if that is also None, the output selected by
+ self.output_key is returned. Defaults to None.
290
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
291
+ """
292
+ device = self.model.device if device is None else device
293
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
294
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
295
+ do_sample = use_default(do_sample, not self.reproduce)
296
+ attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
297
+ outputs = self.model(
298
+ input_ids=batch_encoding["input_ids"].to(device),
299
+ attention_mask=attention_mask,
300
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
301
+ )
302
+ if hidden_state_skip_layer is not None:
303
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
304
+ # Real last hidden state already has layer norm applied. So here we only apply it
305
+ # for intermediate layers.
306
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
307
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
308
+ else:
309
+ last_hidden_state = outputs[self.output_key]
310
+
311
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
312
+ if self.use_template:
313
+ if data_type == "image":
314
+ crop_start = self.prompt_template.get("crop_start", -1)
315
+ elif data_type == "video":
316
+ crop_start = self.prompt_template_video.get("crop_start", -1)
317
+ else:
318
+ raise ValueError(f"Unsupported data type: {data_type}")
319
+ if crop_start > 0:
320
+ last_hidden_state = last_hidden_state[:, crop_start:]
321
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
322
+
323
+ if output_hidden_states:
324
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
325
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
326
+
327
+ def forward(
328
+ self,
329
+ text,
330
+ use_attention_mask=None,
331
+ output_hidden_states=False,
332
+ do_sample=False,
333
+ hidden_state_skip_layer=None,
334
+ return_texts=False,
335
+ ):
336
+ batch_encoding = self.text2tokens(text)
337
+ return self.encode(
338
+ batch_encoding,
339
+ use_attention_mask=use_attention_mask,
340
+ output_hidden_states=output_hidden_states,
341
+ do_sample=do_sample,
342
+ hidden_state_skip_layer=hidden_state_skip_layer,
343
+ return_texts=return_texts,
344
+ )
345
+
346
+
347
+ # region HunyuanVideo architecture
348
+
349
+
350
+ def load_text_encoder_1(
351
+ text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
352
+ ) -> TextEncoder:
353
+ text_encoder_dtype = dtype or torch.float16
354
+ text_encoder_type = "llm"
355
+ text_len = 256
356
+ hidden_state_skip_layer = 2
357
+ apply_final_norm = False
358
+ reproduce = False
359
+
360
+ prompt_template = "dit-llm-encode"
361
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
362
+ prompt_template_video = "dit-llm-encode-video"
363
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
364
+
365
+ crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
366
+ max_length = text_len + crop_start
367
+
368
+ text_encoder_1 = TextEncoder(
369
+ text_encoder_type=text_encoder_type,
370
+ max_length=max_length,
371
+ text_encoder_dtype=text_encoder_dtype,
372
+ text_encoder_path=text_encoder_dir,
373
+ tokenizer_type=text_encoder_type,
374
+ prompt_template=prompt_template,
375
+ prompt_template_video=prompt_template_video,
376
+ hidden_state_skip_layer=hidden_state_skip_layer,
377
+ apply_final_norm=apply_final_norm,
378
+ reproduce=reproduce,
379
+ )
380
+ text_encoder_1.eval()
381
+
382
+ if fp8_llm:
383
+ org_dtype = text_encoder_1.dtype
384
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
385
+ text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
386
+
387
+ # prepare LLM for fp8
388
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
389
+ def forward_hook(module):
390
+ def forward(hidden_states):
391
+ input_dtype = hidden_states.dtype
392
+ hidden_states = hidden_states.to(torch.float32)
393
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
394
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
395
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
396
+
397
+ return forward
398
+
399
+ for module in llama_model.modules():
400
+ if module.__class__.__name__ in ["Embedding"]:
401
+ # print("set", module.__class__.__name__, "to", target_dtype)
402
+ module.to(target_dtype)
403
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
404
+ # print("set", module.__class__.__name__, "hooks")
405
+ module.forward = forward_hook(module)
406
+
407
+ prepare_fp8(text_encoder_1.model, org_dtype)
408
+ else:
409
+ text_encoder_1.to(device=device)
410
+
411
+ return text_encoder_1
412
+
413
+
414
+ def load_text_encoder_2(
415
+ text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
416
+ ) -> TextEncoder:
417
+ text_encoder_dtype = dtype or torch.float16
418
+ reproduce = False
419
+
420
+ text_encoder_2_type = "clipL"
421
+ text_len_2 = 77
422
+
423
+ text_encoder_2 = TextEncoder(
424
+ text_encoder_type=text_encoder_2_type,
425
+ max_length=text_len_2,
426
+ text_encoder_dtype=text_encoder_dtype,
427
+ text_encoder_path=text_encoder_dir,
428
+ tokenizer_type=text_encoder_2_type,
429
+ reproduce=reproduce,
430
+ )
431
+ text_encoder_2.eval()
432
+
433
+ text_encoder_2.to(device=device)
434
+
435
+ return text_encoder_2
436
+
437
+
438
+ # endregion
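+ # Usage sketch (paths are placeholders):
+ # text_encoder_1 = load_text_encoder_1("path/to/llm_text_encoder", torch.device("cuda"), fp8_llm=False)
+ # text_encoder_2 = load_text_encoder_2("path/to/clipL_text_encoder", torch.device("cuda"))
+ # tokens = text_encoder_1.text2tokens("a cat walks on the grass", data_type="video")
+ # output = text_encoder_1.encode(tokens, data_type="video")  # TextEncoderModelOutput(hidden_state, attention_mask)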
hunyuan_model/token_refiner.py ADDED
@@ -0,0 +1,236 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+ from .activation_layers import get_activation_layer
9
+ from .attention import attention
10
+ from .norm_layers import get_norm_layer
11
+ from .embed_layers import TimestepEmbedder, TextProjection
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ hidden_size,
20
+ heads_num,
21
+ mlp_width_ratio: float = 4.0,
22
+ mlp_drop_rate: float = 0.0,
23
+ act_type: str = "silu",
24
+ qk_norm: bool = False,
25
+ qk_norm_type: str = "layer",
26
+ qkv_bias: bool = True,
27
+ dtype: Optional[torch.dtype] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super().__init__()
32
+ self.heads_num = heads_num
33
+ head_dim = hidden_size // heads_num
34
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
+
36
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
37
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
38
+ qk_norm_layer = get_norm_layer(qk_norm_type)
39
+ self.self_attn_q_norm = (
40
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
41
+ )
42
+ self.self_attn_k_norm = (
43
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
44
+ )
45
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
46
+
47
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
48
+ act_layer = get_activation_layer(act_type)
49
+ self.mlp = MLP(
50
+ in_channels=hidden_size,
51
+ hidden_channels=mlp_hidden_dim,
52
+ act_layer=act_layer,
53
+ drop=mlp_drop_rate,
54
+ **factory_kwargs,
55
+ )
56
+
57
+ self.adaLN_modulation = nn.Sequential(
58
+ act_layer(),
59
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
60
+ )
61
+ # Zero-initialize the modulation
62
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
63
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
64
+
65
+ self.gradient_checkpointing = False
66
+
67
+ def enable_gradient_checkpointing(self):
68
+ self.gradient_checkpointing = True
69
+
70
+ def _forward(
71
+ self,
72
+ x: torch.Tensor,
73
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
74
+ attn_mask: torch.Tensor = None,
75
+ ):
76
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
77
+
78
+ norm_x = self.norm1(x)
79
+ qkv = self.self_attn_qkv(norm_x)
80
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
81
+ # Apply QK-Norm if needed
82
+ q = self.self_attn_q_norm(q).to(v)
83
+ k = self.self_attn_k_norm(k).to(v)
84
+
85
+ # Self-Attention
86
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
87
+
88
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
89
+
90
+ # FFN Layer
91
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
92
+
93
+ return x
94
+
95
+ def forward(self, *args, **kwargs):
96
+ if self.training and self.gradient_checkpointing:
97
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
98
+ else:
99
+ return self._forward(*args, **kwargs)
100
+
101
+
102
+
103
+ class IndividualTokenRefiner(nn.Module):
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ heads_num,
108
+ depth,
109
+ mlp_width_ratio: float = 4.0,
110
+ mlp_drop_rate: float = 0.0,
111
+ act_type: str = "silu",
112
+ qk_norm: bool = False,
113
+ qk_norm_type: str = "layer",
114
+ qkv_bias: bool = True,
115
+ dtype: Optional[torch.dtype] = None,
116
+ device: Optional[torch.device] = None,
117
+ ):
118
+ factory_kwargs = {"device": device, "dtype": dtype}
119
+ super().__init__()
120
+ self.blocks = nn.ModuleList(
121
+ [
122
+ IndividualTokenRefinerBlock(
123
+ hidden_size=hidden_size,
124
+ heads_num=heads_num,
125
+ mlp_width_ratio=mlp_width_ratio,
126
+ mlp_drop_rate=mlp_drop_rate,
127
+ act_type=act_type,
128
+ qk_norm=qk_norm,
129
+ qk_norm_type=qk_norm_type,
130
+ qkv_bias=qkv_bias,
131
+ **factory_kwargs,
132
+ )
133
+ for _ in range(depth)
134
+ ]
135
+ )
136
+
137
+ def enable_gradient_checkpointing(self):
138
+ for block in self.blocks:
139
+ block.enable_gradient_checkpointing()
140
+
141
+ def forward(
142
+ self,
143
+ x: torch.Tensor,
144
+ c: torch.LongTensor,
145
+ mask: Optional[torch.Tensor] = None,
146
+ ):
147
+ self_attn_mask = None
148
+ if mask is not None:
149
+ batch_size = mask.shape[0]
150
+ seq_len = mask.shape[1]
151
+ mask = mask.to(x.device)
152
+ # batch_size x 1 x seq_len x seq_len
153
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
154
+ # batch_size x 1 x seq_len x seq_len
155
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
156
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
157
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
158
+ # avoids self-attention weight being NaN for padding tokens
159
+ self_attn_mask[:, :, :, 0] = True
160
+
161
+ for block in self.blocks:
162
+ x = block(x, c, self_attn_mask)
163
+ return x
164
+
165
+
166
+ class SingleTokenRefiner(nn.Module):
167
+ """
168
+ A single token refiner block for LLM text embedding refinement.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ in_channels,
174
+ hidden_size,
175
+ heads_num,
176
+ depth,
177
+ mlp_width_ratio: float = 4.0,
178
+ mlp_drop_rate: float = 0.0,
179
+ act_type: str = "silu",
180
+ qk_norm: bool = False,
181
+ qk_norm_type: str = "layer",
182
+ qkv_bias: bool = True,
183
+ attn_mode: str = "torch",
184
+ dtype: Optional[torch.dtype] = None,
185
+ device: Optional[torch.device] = None,
186
+ ):
187
+ factory_kwargs = {"device": device, "dtype": dtype}
188
+ super().__init__()
189
+ self.attn_mode = attn_mode
190
+ assert self.attn_mode == "torch", "Only 'torch' attention mode is supported for the token refiner."
191
+
192
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
193
+
194
+ act_layer = get_activation_layer(act_type)
195
+ # Build timestep embedding layer
196
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
197
+ # Build context embedding layer
198
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
199
+
200
+ self.individual_token_refiner = IndividualTokenRefiner(
201
+ hidden_size=hidden_size,
202
+ heads_num=heads_num,
203
+ depth=depth,
204
+ mlp_width_ratio=mlp_width_ratio,
205
+ mlp_drop_rate=mlp_drop_rate,
206
+ act_type=act_type,
207
+ qk_norm=qk_norm,
208
+ qk_norm_type=qk_norm_type,
209
+ qkv_bias=qkv_bias,
210
+ **factory_kwargs,
211
+ )
212
+
213
+ def enable_gradient_checkpointing(self):
214
+ self.individual_token_refiner.enable_gradient_checkpointing()
215
+
216
+ def forward(
217
+ self,
218
+ x: torch.Tensor,
219
+ t: torch.LongTensor,
220
+ mask: Optional[torch.LongTensor] = None,
221
+ ):
222
+ timestep_aware_representations = self.t_embedder(t)
223
+
224
+ if mask is None:
225
+ context_aware_representations = x.mean(dim=1)
226
+ else:
227
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
228
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
229
+ context_aware_representations = self.c_embedder(context_aware_representations)
230
+ c = timestep_aware_representations + context_aware_representations
231
+
232
+ x = self.input_embedder(x)
233
+
234
+ x = self.individual_token_refiner(x, c, mask)
235
+
236
+ return x
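
To make the data flow of SingleTokenRefiner concrete, here is a small self-contained sketch with toy dimensions; the sizes are illustrative only and do not match the real model configuration.

import torch
from hunyuan_model.token_refiner import SingleTokenRefiner

# Toy dimensions: 64-dim "LLM" hidden states refined into 128-dim text tokens.
refiner = SingleTokenRefiner(in_channels=64, hidden_size=128, heads_num=4, depth=2)
refiner.eval()

x = torch.randn(2, 16, 64)                  # [batch, seq_len, in_channels] text hidden states
t = torch.zeros(2, dtype=torch.long)        # one timestep per sample
mask = torch.ones(2, 16, dtype=torch.long)  # 1 = real token, 0 = padding
mask[1, 10:] = 0                            # second sample has only 10 valid tokens

with torch.no_grad():
    out = refiner(x, t, mask)  # c = timestep embedding + masked-mean context, gating each block
print(out.shape)               # torch.Size([2, 16, 128])
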
hunyuan_model/vae.py ADDED
@@ -0,0 +1,442 @@
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import Optional, Tuple, Union
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from diffusers.utils import BaseOutput, is_torch_version
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.models.attention_processor import SpatialNorm
13
+ from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
14
+
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+
21
+ SCALING_FACTOR = 0.476986
22
+ VAE_VER = "884-16c-hy"
23
+
24
+
25
+ def load_vae(
26
+ vae_type: str = "884-16c-hy",
27
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
28
+ sample_size: tuple = None,
29
+ vae_path: str = None,
30
+ device=None,
31
+ ):
32
+ """the fucntion to load the 3D VAE model
33
+
34
+ Args:
35
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
36
+ vae_dtype (str or torch.dtype, optional): the dtype to load the VAE in. Defaults to None.
37
+ sample_size (tuple, optional): the tiling size. Defaults to None.
38
+ vae_path (str, optional): the path to vae. Defaults to None.
40
+ device (_type_, optional): device to load vae. Defaults to None.
41
+ """
42
+ if vae_path is None:
43
+ vae_path = VAE_PATH[vae_type]
44
+
45
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
46
+
47
+ # use fixed config for Hunyuan's VAE
48
+ CONFIG_JSON = """{
49
+ "_class_name": "AutoencoderKLCausal3D",
50
+ "_diffusers_version": "0.4.2",
51
+ "act_fn": "silu",
52
+ "block_out_channels": [
53
+ 128,
54
+ 256,
55
+ 512,
56
+ 512
57
+ ],
58
+ "down_block_types": [
59
+ "DownEncoderBlockCausal3D",
60
+ "DownEncoderBlockCausal3D",
61
+ "DownEncoderBlockCausal3D",
62
+ "DownEncoderBlockCausal3D"
63
+ ],
64
+ "in_channels": 3,
65
+ "latent_channels": 16,
66
+ "layers_per_block": 2,
67
+ "norm_num_groups": 32,
68
+ "out_channels": 3,
69
+ "sample_size": 256,
70
+ "sample_tsize": 64,
71
+ "up_block_types": [
72
+ "UpDecoderBlockCausal3D",
73
+ "UpDecoderBlockCausal3D",
74
+ "UpDecoderBlockCausal3D",
75
+ "UpDecoderBlockCausal3D"
76
+ ],
77
+ "scaling_factor": 0.476986,
78
+ "time_compression_ratio": 4,
79
+ "mid_block_add_attention": true
80
+ }"""
81
+
82
+ # config = AutoencoderKLCausal3D.load_config(vae_path)
83
+ config = json.loads(CONFIG_JSON)
84
+
85
+ # import here to avoid circular import
86
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
87
+
88
+ if sample_size:
89
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
90
+ else:
91
+ vae = AutoencoderKLCausal3D.from_config(config)
92
+
93
+ # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
94
+ # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
95
+
96
+ ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
97
+ if "state_dict" in ckpt:
98
+ ckpt = ckpt["state_dict"]
99
+ if any(k.startswith("vae.") for k in ckpt.keys()):
100
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
101
+ vae.load_state_dict(ckpt)
102
+
103
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
104
+ time_compression_ratio = vae.config.time_compression_ratio
105
+
106
+ if vae_dtype is not None:
107
+ vae = vae.to(vae_dtype)
108
+
109
+ vae.requires_grad_(False)
110
+
111
+ logger.info(f"VAE to dtype: {vae.dtype}")
112
+
113
+ if device is not None:
114
+ vae = vae.to(device)
115
+
116
+ vae.eval()
117
+
118
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
119
+
120
+
121
+ @dataclass
122
+ class DecoderOutput(BaseOutput):
123
+ r"""
124
+ Output of decoding method.
125
+
126
+ Args:
127
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
128
+ The decoded output sample from the last layer of the model.
129
+ """
130
+
131
+ sample: torch.FloatTensor
132
+
133
+
134
+ class EncoderCausal3D(nn.Module):
135
+ r"""
136
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ in_channels: int = 3,
142
+ out_channels: int = 3,
143
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
144
+ block_out_channels: Tuple[int, ...] = (64,),
145
+ layers_per_block: int = 2,
146
+ norm_num_groups: int = 32,
147
+ act_fn: str = "silu",
148
+ double_z: bool = True,
149
+ mid_block_add_attention=True,
150
+ time_compression_ratio: int = 4,
151
+ spatial_compression_ratio: int = 8,
152
+ ):
153
+ super().__init__()
154
+ self.layers_per_block = layers_per_block
155
+
156
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
157
+ self.mid_block = None
158
+ self.down_blocks = nn.ModuleList([])
159
+
160
+ # down
161
+ output_channel = block_out_channels[0]
162
+ for i, down_block_type in enumerate(down_block_types):
163
+ input_channel = output_channel
164
+ output_channel = block_out_channels[i]
165
+ is_final_block = i == len(block_out_channels) - 1
166
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
167
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
168
+
169
+ if time_compression_ratio == 4:
170
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
171
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
172
+ else:
173
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
174
+
175
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
176
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
177
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
178
+ down_block = get_down_block3d(
179
+ down_block_type,
180
+ num_layers=self.layers_per_block,
181
+ in_channels=input_channel,
182
+ out_channels=output_channel,
183
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
184
+ downsample_stride=downsample_stride,
185
+ resnet_eps=1e-6,
186
+ downsample_padding=0,
187
+ resnet_act_fn=act_fn,
188
+ resnet_groups=norm_num_groups,
189
+ attention_head_dim=output_channel,
190
+ temb_channels=None,
191
+ )
192
+ self.down_blocks.append(down_block)
193
+
194
+ # mid
195
+ self.mid_block = UNetMidBlockCausal3D(
196
+ in_channels=block_out_channels[-1],
197
+ resnet_eps=1e-6,
198
+ resnet_act_fn=act_fn,
199
+ output_scale_factor=1,
200
+ resnet_time_scale_shift="default",
201
+ attention_head_dim=block_out_channels[-1],
202
+ resnet_groups=norm_num_groups,
203
+ temb_channels=None,
204
+ add_attention=mid_block_add_attention,
205
+ )
206
+
207
+ # out
208
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
209
+ self.conv_act = nn.SiLU()
210
+
211
+ conv_out_channels = 2 * out_channels if double_z else out_channels
212
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
213
+
214
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
215
+ r"""The forward method of the `EncoderCausal3D` class."""
216
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
217
+
218
+ sample = self.conv_in(sample)
219
+
220
+ # down
221
+ for down_block in self.down_blocks:
222
+ sample = down_block(sample)
223
+
224
+ # middle
225
+ sample = self.mid_block(sample)
226
+
227
+ # post-process
228
+ sample = self.conv_norm_out(sample)
229
+ sample = self.conv_act(sample)
230
+ sample = self.conv_out(sample)
231
+
232
+ return sample
233
+
234
+
235
+ class DecoderCausal3D(nn.Module):
236
+ r"""
237
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
238
+ """
239
+
240
+ def __init__(
241
+ self,
242
+ in_channels: int = 3,
243
+ out_channels: int = 3,
244
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
245
+ block_out_channels: Tuple[int, ...] = (64,),
246
+ layers_per_block: int = 2,
247
+ norm_num_groups: int = 32,
248
+ act_fn: str = "silu",
249
+ norm_type: str = "group", # group, spatial
250
+ mid_block_add_attention=True,
251
+ time_compression_ratio: int = 4,
252
+ spatial_compression_ratio: int = 8,
253
+ ):
254
+ super().__init__()
255
+ self.layers_per_block = layers_per_block
256
+
257
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
258
+ self.mid_block = None
259
+ self.up_blocks = nn.ModuleList([])
260
+
261
+ temb_channels = in_channels if norm_type == "spatial" else None
262
+
263
+ # mid
264
+ self.mid_block = UNetMidBlockCausal3D(
265
+ in_channels=block_out_channels[-1],
266
+ resnet_eps=1e-6,
267
+ resnet_act_fn=act_fn,
268
+ output_scale_factor=1,
269
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
270
+ attention_head_dim=block_out_channels[-1],
271
+ resnet_groups=norm_num_groups,
272
+ temb_channels=temb_channels,
273
+ add_attention=mid_block_add_attention,
274
+ )
275
+
276
+ # up
277
+ reversed_block_out_channels = list(reversed(block_out_channels))
278
+ output_channel = reversed_block_out_channels[0]
279
+ for i, up_block_type in enumerate(up_block_types):
280
+ prev_output_channel = output_channel
281
+ output_channel = reversed_block_out_channels[i]
282
+ is_final_block = i == len(block_out_channels) - 1
283
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
284
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
285
+
286
+ if time_compression_ratio == 4:
287
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
288
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
289
+ else:
290
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
291
+
292
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
293
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
294
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
295
+ up_block = get_up_block3d(
296
+ up_block_type,
297
+ num_layers=self.layers_per_block + 1,
298
+ in_channels=prev_output_channel,
299
+ out_channels=output_channel,
300
+ prev_output_channel=None,
301
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
302
+ upsample_scale_factor=upsample_scale_factor,
303
+ resnet_eps=1e-6,
304
+ resnet_act_fn=act_fn,
305
+ resnet_groups=norm_num_groups,
306
+ attention_head_dim=output_channel,
307
+ temb_channels=temb_channels,
308
+ resnet_time_scale_shift=norm_type,
309
+ )
310
+ self.up_blocks.append(up_block)
311
+ prev_output_channel = output_channel
312
+
313
+ # out
314
+ if norm_type == "spatial":
315
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
316
+ else:
317
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
318
+ self.conv_act = nn.SiLU()
319
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
320
+
321
+ self.gradient_checkpointing = False
322
+
323
+ def forward(
324
+ self,
325
+ sample: torch.FloatTensor,
326
+ latent_embeds: Optional[torch.FloatTensor] = None,
327
+ ) -> torch.FloatTensor:
328
+ r"""The forward method of the `DecoderCausal3D` class."""
329
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
330
+
331
+ sample = self.conv_in(sample)
332
+
333
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
334
+ if self.training and self.gradient_checkpointing:
335
+
336
+ def create_custom_forward(module):
337
+ def custom_forward(*inputs):
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ if is_torch_version(">=", "1.11.0"):
343
+ # middle
344
+ sample = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(self.mid_block),
346
+ sample,
347
+ latent_embeds,
348
+ use_reentrant=False,
349
+ )
350
+ sample = sample.to(upscale_dtype)
351
+
352
+ # up
353
+ for up_block in self.up_blocks:
354
+ sample = torch.utils.checkpoint.checkpoint(
355
+ create_custom_forward(up_block),
356
+ sample,
357
+ latent_embeds,
358
+ use_reentrant=False,
359
+ )
360
+ else:
361
+ # middle
362
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
363
+ sample = sample.to(upscale_dtype)
364
+
365
+ # up
366
+ for up_block in self.up_blocks:
367
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
368
+ else:
369
+ # middle
370
+ sample = self.mid_block(sample, latent_embeds)
371
+ sample = sample.to(upscale_dtype)
372
+
373
+ # up
374
+ for up_block in self.up_blocks:
375
+ sample = up_block(sample, latent_embeds)
376
+
377
+ # post-process
378
+ if latent_embeds is None:
379
+ sample = self.conv_norm_out(sample)
380
+ else:
381
+ sample = self.conv_norm_out(sample, latent_embeds)
382
+ sample = self.conv_act(sample)
383
+ sample = self.conv_out(sample)
384
+
385
+ return sample
386
+
387
+
388
+ class DiagonalGaussianDistribution(object):
389
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
390
+ if parameters.ndim == 3:
391
+ dim = 2 # (B, L, C)
392
+ elif parameters.ndim == 5 or parameters.ndim == 4:
393
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
394
+ else:
395
+ raise NotImplementedError
396
+ self.parameters = parameters
397
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
398
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
399
+ self.deterministic = deterministic
400
+ self.std = torch.exp(0.5 * self.logvar)
401
+ self.var = torch.exp(self.logvar)
402
+ if self.deterministic:
403
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
404
+
405
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
406
+ # make sure sample is on the same device as the parameters and has same dtype
407
+ sample = randn_tensor(
408
+ self.mean.shape,
409
+ generator=generator,
410
+ device=self.parameters.device,
411
+ dtype=self.parameters.dtype,
412
+ )
413
+ x = self.mean + self.std * sample
414
+ return x
415
+
416
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
417
+ if self.deterministic:
418
+ return torch.Tensor([0.0])
419
+ else:
420
+ reduce_dim = list(range(1, self.mean.ndim))
421
+ if other is None:
422
+ return 0.5 * torch.sum(
423
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
424
+ dim=reduce_dim,
425
+ )
426
+ else:
427
+ return 0.5 * torch.sum(
428
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
429
+ dim=reduce_dim,
430
+ )
431
+
432
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
433
+ if self.deterministic:
434
+ return torch.Tensor([0.0])
435
+ logtwopi = np.log(2.0 * np.pi)
436
+ return 0.5 * torch.sum(
437
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
438
+ dim=dims,
439
+ )
440
+
441
+ def mode(self) -> torch.Tensor:
442
+ return self.mean
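
Finally, a short sketch of how DiagonalGaussianDistribution turns the encoder output into latents. The tensor shape is illustrative, and multiplying by SCALING_FACTOR follows the usual convention of scaling latents before handing them to the diffusion model.

import torch
from hunyuan_model.vae import DiagonalGaussianDistribution, SCALING_FACTOR

# The encoder emits mean and logvar concatenated along the channel axis,
# i.e. 2 * latent_channels = 32 channels for this 16-channel VAE.
moments = torch.randn(1, 32, 5, 32, 32)  # [B, 2*C, T, H, W] (illustrative shape)

posterior = DiagonalGaussianDistribution(moments)
latents = posterior.sample() * SCALING_FACTOR    # stochastic sample, scaled for the diffusion model
latents_det = posterior.mode() * SCALING_FACTOR  # deterministic alternative (posterior mean)

print(latents.shape)  # torch.Size([1, 16, 5, 32, 32])
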