svjack committed on
Commit bce3e7c · verified
1 Parent(s): a0899f4

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +4 -0
  2. cache_latents.py +245 -0
  3. cache_text_encoder_outputs.py +135 -0
  4. convert_lora.py +129 -0
  5. dataset/__init__.py +0 -0
  6. dataset/config_utils.py +359 -0
  7. dataset/dataset_config.md +293 -0
  8. dataset/image_video_dataset.py +1255 -0
  9. hunyuan_model/__init__.py +0 -0
  10. hunyuan_model/activation_layers.py +23 -0
  11. hunyuan_model/attention.py +230 -0
  12. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  13. hunyuan_model/embed_layers.py +132 -0
  14. hunyuan_model/helpers.py +40 -0
  15. hunyuan_model/mlp_layers.py +118 -0
  16. hunyuan_model/models.py +997 -0
  17. hunyuan_model/modulate_layers.py +76 -0
  18. hunyuan_model/norm_layers.py +79 -0
  19. hunyuan_model/pipeline_hunyuan_video.py +1100 -0
  20. hunyuan_model/posemb_layers.py +310 -0
  21. hunyuan_model/text_encoder.py +438 -0
  22. hunyuan_model/token_refiner.py +236 -0
  23. hunyuan_model/vae.py +442 -0
  24. hv_generate_video.py +563 -0
  25. hv_train_network.py +2129 -0
  26. modules/__init__.py +0 -0
  27. modules/custom_offloading_utils.py +262 -0
  28. modules/scheduling_flow_match_discrete.py +257 -0
  29. modules/unet_causal_3d_blocks.py +818 -0
  30. networks/__init__.py +0 -0
  31. networks/lora.py +828 -0
  32. requirements.txt +18 -0
  33. utils/__init__.py +0 -0
  34. utils/huggingface_utils.py +89 -0
  35. utils/model_utils.py +151 -0
  36. utils/safetensors_utils.py +191 -0
  37. utils/sai_model_spec.py +263 -0
  38. utils/train_utils.py +177 -0
  39. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000001.safetensors +3 -0
  40. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000002.safetensors +3 -0
  41. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000003.safetensors +3 -0
  42. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000004.safetensors +3 -0
  43. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000005.safetensors +3 -0
  44. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000006.safetensors +3 -0
  45. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000007.safetensors +3 -0
  46. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000008.safetensors +3 -0
  47. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000009.safetensors +3 -0
  48. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000010.safetensors +3 -0
  49. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000011.safetensors +3 -0
  50. zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000012.safetensors +3 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
cache_latents.py ADDED
@@ -0,0 +1,245 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ from PIL import Image
12
+
13
+ import logging
14
+
15
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache
16
+ from hunyuan_model.vae import load_vae
17
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
18
+ from utils.model_utils import str_to_dtype
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
+ def show_image(image: Union[Image.Image, np.ndarray, list[Union[Image.Image, np.ndarray]]]) -> int:
25
+ import cv2
26
+
27
+ imgs = (
28
+ [image]
29
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
30
+ else [image[0], image[-1]]
31
+ )
32
+ if len(imgs) > 1:
33
+ print(f"Number of images: {len(image)}")
34
+ for i, img in enumerate(imgs):
35
+ if len(imgs) > 1:
36
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
37
+ else:
38
+ print(f"Image: {img.shape}")
39
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
40
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
41
+ cv2.imshow("image", cv2_img)
42
+ k = cv2.waitKey(0)
43
+ cv2.destroyAllWindows()
44
+ if k == ord("q") or k == ord("d"):
45
+ return k
46
+ return k
47
+
48
+
49
+ def show_console(
50
+ image: Union[Image.Image, np.ndarray, list[Union[Image.Image, np.ndarray]]],
51
+ width: int,
52
+ back: str,
53
+ interactive: bool = False,
54
+ ) -> int:
55
+ from ascii_magic import from_pillow_image, Back
56
+
57
+ back = None # NOTE: back is unconditionally reset to None here, so the back argument (--console_back) has no effect
58
+ if back is not None:
59
+ back = getattr(Back, back.upper())
60
+
61
+ k = None
62
+ imgs = (
63
+ [image]
64
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
65
+ else [image[0], image[-1]]
66
+ )
67
+ if len(imgs) > 1:
68
+ print(f"Number of images: {len(image)}")
69
+ for i, img in enumerate(imgs):
70
+ if len(imgs) > 1:
71
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
72
+ else:
73
+ print(f"Image: {img.shape}")
74
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
75
+ ascii_img = from_pillow_image(pil_img)
76
+ ascii_img.to_terminal(columns=width, back=back)
77
+
78
+ if interactive:
79
+ k = input("Press q to quit, d to next dataset, other key to next: ")
80
+ if k == "q" or k == "d":
81
+ return ord(k)
82
+
83
+ if not interactive:
84
+ return ord(" ")
85
+ return ord(k) if k else ord(" ")
86
+
87
+
88
+ def show_datasets(
89
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
90
+ ):
91
+ print(f"d: next dataset, q: quit")
92
+
93
+ num_workers = max(1, os.cpu_count() - 1)
94
+ for i, dataset in enumerate(datasets):
95
+ print(f"Dataset [{i}]")
96
+ batch_index = 0
97
+ num_images_to_show = console_num_images
98
+ k = None
99
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
100
+ print(f"bucket resolution: {key}, count: {len(batch)}")
101
+ for j, item_info in enumerate(batch):
102
+ item_info: ItemInfo
103
+ print(f"{batch_index}-{j}: {item_info}")
104
+ if debug_mode == "image":
105
+ k = show_image(item_info.content)
106
+ elif debug_mode == "console":
107
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
108
+ if num_images_to_show is not None:
109
+ num_images_to_show -= 1
110
+ if num_images_to_show == 0:
111
+ k = ord("d") # next dataset
112
+
113
+ if k == ord("q"):
114
+ return
115
+ elif k == ord("d"):
116
+ break
117
+ if k == ord("d"):
118
+ break
119
+ batch_index += 1
120
+
121
+
122
+ def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
123
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
124
+ if len(contents.shape) == 4:
125
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
126
+
127
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
128
+ contents = contents.to(vae.device, dtype=vae.dtype)
129
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
130
+
131
+ # print(f"encode batch: {contents.shape}")
132
+ with torch.no_grad():
133
+ latent = vae.encode(contents).latent_dist.sample()
134
+ latent = latent * vae.config.scaling_factor
135
+
136
+ # # debug: decode and save
137
+ # with torch.no_grad():
138
+ # latent_to_decode = latent / vae.config.scaling_factor
139
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
140
+ # images = (images / 2 + 0.5).clamp(0, 1)
141
+ # images = images.cpu().float().numpy()
142
+ # images = (images * 255).astype(np.uint8)
143
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
144
+ # for b in range(images.shape[0]):
145
+ # for f in range(images.shape[1]):
146
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
147
+ # img = Image.fromarray(images[b, f])
148
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
149
+
150
+ for item, l in zip(batch, latent):
151
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
152
+ save_latent_cache(item, l)
153
+
154
+
155
+ def main(args):
156
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
157
+ device = torch.device(device)
158
+
159
+ # Load dataset config
160
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
161
+ logger.info(f"Load dataset config from {args.dataset_config}")
162
+ user_config = config_utils.load_user_config(args.dataset_config)
163
+ blueprint = blueprint_generator.generate(user_config, args)
164
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
165
+
166
+ datasets = train_dataset_group.datasets
167
+
168
+ if args.debug_mode is not None:
169
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
170
+ return
171
+
172
+ assert args.vae is not None, "vae checkpoint is required"
173
+
174
+ # Load VAE model: HunyuanVideo VAE model is float16
175
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
176
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
177
+ vae.eval()
178
+ print(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
179
+
180
+ if args.vae_chunk_size is not None:
181
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
182
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
183
+ if args.vae_spatial_tile_sample_min_size is not None:
184
+ vae.enable_spatial_tiling(True)
185
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
186
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
187
+ elif args.vae_tiling:
188
+ vae.enable_spatial_tiling(True)
189
+
190
+ # Encode images
191
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
192
+ for i, dataset in enumerate(datasets):
193
+ print(f"Encoding dataset [{i}]")
194
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
195
+ if args.skip_existing:
196
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
197
+ if len(filtered_batch) == 0:
198
+ continue
199
+ batch = filtered_batch
200
+
201
+ bs = args.batch_size if args.batch_size is not None else len(batch)
202
+ for i in range(0, len(batch), bs):
203
+ encode_and_save_batch(vae, batch[i : i + bs])
204
+
205
+
206
+ def setup_parser():
207
+ parser = argparse.ArgumentParser()
208
+
209
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
210
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
211
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
212
+ parser.add_argument(
213
+ "--vae_tiling",
214
+ action="store_true",
215
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
216
+ )
217
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
218
+ parser.add_argument(
219
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
220
+ )
221
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
222
+ parser.add_argument(
223
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
224
+ )
225
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
226
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
227
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
228
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
229
+ parser.add_argument(
230
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
231
+ )
232
+ parser.add_argument(
233
+ "--console_num_images",
234
+ type=int,
235
+ default=None,
236
+ help="debug mode: not interactive, number of images to show for each dataset",
237
+ )
238
+ return parser
239
+
240
+
241
+ if __name__ == "__main__":
242
+ parser = setup_parser()
243
+
244
+ args = parser.parse_args()
245
+ main(args)
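cache_latents.py encodes each bucketed batch with the HunyuanVideo VAE and writes one latent cache file per item via save_latent_cache (defined in dataset/image_video_dataset.py later in this diff). A typical invocation would be `python cache_latents.py --dataset_config path/to/dataset.toml --vae path/to/vae_checkpoint --vae_tiling` (paths are placeholders). The sketch below is illustrative and not part of this commit; the cache file path is a placeholder, and only the key and metadata naming follows the code in this diff.

```python
# Illustrative sketch (not part of this commit): inspect a latent cache file written by
# cache_latents.py. The path is a placeholder; the key name follows the
# "latents_{F}x{H}x{W}_{dtype}" scheme used by save_latent_cache.
from safetensors import safe_open

with safe_open("path/to/item_latent_cache.safetensors", framework="pt") as f:
    print(f.metadata())   # architecture, width, height, format_version, optional frame_count
    for key in f.keys():  # e.g. "latents_1x68x120_float16"
        print(key, f.get_tensor(key).shape)
```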
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
26
+ data_type = "video" # video only, image is not supported
27
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
28
+
29
+ with torch.no_grad():
30
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
31
+
32
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
33
+
34
+
35
+ def encode_and_save_batch(
36
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
37
+ ):
38
+ prompts = [item.caption for item in batch]
39
+ # print(prompts)
40
+
41
+ # encode prompt
42
+ if accelerator is not None:
43
+ with accelerator.autocast():
44
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
45
+ else:
46
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
47
+
48
+ # # convert to fp16 if needed
49
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
50
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
51
+
52
+ # save prompt cache
53
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
54
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
+ def main(args):
58
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
59
+ device = torch.device(device)
60
+
61
+ # Load dataset config
62
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
63
+ logger.info(f"Load dataset config from {args.dataset_config}")
64
+ user_config = config_utils.load_user_config(args.dataset_config)
65
+ blueprint = blueprint_generator.generate(user_config, args)
66
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
67
+
68
+ datasets = train_dataset_group.datasets
69
+
70
+ # define accelerator for fp8 inference
71
+ accelerator = None
72
+ if args.fp8_llm:
73
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
74
+
75
+ # define encode function
76
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
77
+
78
+ def encode_for_text_encoder(text_encoder: TextEncoder, is_llm: bool):
79
+ for i, dataset in enumerate(datasets):
80
+ print(f"Encoding dataset [{i}]")
81
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
82
+ if args.skip_existing:
83
+ filtered_batch = [item for item in batch if not os.path.exists(item.text_encoder_output_cache_path)]
84
+ if len(filtered_batch) == 0:
85
+ continue
86
+ batch = filtered_batch
87
+
88
+ bs = args.batch_size if args.batch_size is not None else len(batch)
89
+ for i in range(0, len(batch), bs):
90
+ encode_and_save_batch(text_encoder, batch[i : i + bs], is_llm, accelerator)
91
+
92
+ # Load Text Encoder 1
93
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
94
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
95
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
96
+ text_encoder_1.to(device=device)
97
+
98
+ # Encode with Text Encoder 1
99
+ logger.info("Encoding with Text Encoder 1")
100
+ encode_for_text_encoder(text_encoder_1, is_llm=True)
101
+ del text_encoder_1
102
+
103
+ # Load Text Encoder 2
104
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
105
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
106
+ text_encoder_2.to(device=device)
107
+
108
+ # Encode with Text Encoder 2
109
+ logger.info("Encoding with Text Encoder 2")
110
+ encode_for_text_encoder(text_encoder_2, is_llm=False)
111
+ del text_encoder_2
112
+
113
+
114
+ def setup_parser():
115
+ parser = argparse.ArgumentParser()
116
+
117
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
118
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
119
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
120
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
121
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
122
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
123
+ parser.add_argument(
124
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
125
+ )
126
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
127
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
128
+ return parser
129
+
130
+
131
+ if __name__ == "__main__":
132
+ parser = setup_parser()
133
+
134
+ args = parser.parse_args()
135
+ main(args)
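The script caches outputs for both text encoders in sequence: Text Encoder 1 (the LLM, optionally run in fp8 via --fp8_llm) first, then Text Encoder 2 after encoder 1 has been freed. A minimal invocation sketch follows; it is not part of this commit, and all paths and values are placeholders.

```python
# Illustrative invocation sketch (not part of this commit); paths and batch size are placeholders.
import subprocess

subprocess.run(
    [
        "python", "cache_text_encoder_outputs.py",
        "--dataset_config", "path/to/dataset.toml",
        "--text_encoder1", "path/to/text_encoder_1",  # LLM; cached first, optionally with --fp8_llm
        "--text_encoder2", "path/to/text_encoder_2",  # second text encoder; cached after encoder 1
        "--batch_size", "16",
        "--skip_existing",
    ],
    check=True,
)
```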
convert_lora.py ADDED
@@ -0,0 +1,129 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ def convert_from_diffusers(prefix, weights_sd):
16
+ # convert from diffusers(?) to default LoRA
17
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
18
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
19
+ # note: Diffusers has no alpha, so alpha is set to rank
20
+ new_weights_sd = {}
21
+ lora_dims = {}
22
+ for key, weight in weights_sd.items():
23
+ diffusers_prefix, key_body = key.split(".", 1)
24
+ if diffusers_prefix != "diffusion_model":
25
+ logger.warning(f"unexpected key: {key} in diffusers format")
26
+ continue
27
+
28
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
29
+ new_weights_sd[new_key] = weight
30
+
31
+ lora_name = new_key.split(".")[0] # before first dot
32
+ if lora_name not in lora_dims and "lora_down" in new_key:
33
+ lora_dims[lora_name] = weight.shape[0]
34
+
35
+ # add alpha with rank
36
+ for lora_name, dim in lora_dims.items():
37
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
38
+
39
+ return new_weights_sd
40
+
41
+
42
+ def convert_to_diffusers(prefix, weights_sd):
43
+ # convert from default LoRA to diffusers
44
+
45
+ # get alphas
46
+ lora_alphas = {}
47
+ for key, weight in weights_sd.items():
48
+ if key.startswith(prefix):
49
+ lora_name = key.split(".", 1)[0] # before first dot
50
+ if lora_name not in lora_alphas and "alpha" in key:
51
+ lora_alphas[lora_name] = weight
52
+
53
+ new_weights_sd = {}
54
+ for key, weight in weights_sd.items():
55
+ if key.startswith(prefix):
56
+ if "alpha" in key:
57
+ continue
58
+
59
+ lora_name = key.split(".", 1)[0] # before first dot
60
+
61
+ # HunyuanVideo lora name to module name: ugly but works
62
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
63
+ module_name = module_name.replace("_", ".") # replace "_" with "."
64
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
65
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
66
+ module_name = module_name.replace("img.", "img_") # fix img
67
+ module_name = module_name.replace("txt.", "txt_") # fix txt
68
+ module_name = module_name.replace("attn.", "attn_") # fix attn
69
+
70
+ diffusers_prefix = "diffusion_model"
71
+ if "lora_down" in key:
72
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
73
+ dim = weight.shape[0]
74
+ elif "lora_up" in key:
75
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
76
+ dim = weight.shape[1]
77
+ else:
78
+ logger.warning(f"unexpected key: {key} in default LoRA format")
79
+ continue
80
+
81
+ # scale weight by alpha
82
+ if lora_name in lora_alphas:
83
+ # we scale both down and up, so scale is sqrt
84
+ scale = lora_alphas[lora_name] / dim
85
+ scale = scale.sqrt()
86
+ weight = weight * scale
87
+ else:
88
+ logger.warning(f"missing alpha for {lora_name}")
89
+
90
+ new_weights_sd[new_key] = weight
91
+
92
+ return new_weights_sd
93
+
94
+
95
+ def convert(input_file, output_file, target_format):
96
+ logger.info(f"loading {input_file}")
97
+ weights_sd = load_file(input_file)
98
+ with safe_open(input_file, framework="pt") as f:
99
+ metadata = f.metadata()
100
+
101
+ logger.info(f"converting to {target_format}")
102
+ prefix = "lora_unet_"
103
+ if target_format == "default":
104
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
105
+ metadata = metadata or {}
106
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
107
+ elif target_format == "other":
108
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
109
+ else:
110
+ raise ValueError(f"unknown target format: {target_format}")
111
+
112
+ logger.info(f"saving to {output_file}")
113
+ save_file(new_weights_sd, output_file, metadata=metadata)
114
+
115
+ logger.info("done")
116
+
117
+
118
+ def parse_args():
119
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
120
+ parser.add_argument("--input", type=str, required=True, help="input model file")
121
+ parser.add_argument("--output", type=str, required=True, help="output model file")
122
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ if __name__ == "__main__":
128
+ args = parse_args()
129
+ convert(args.input, args.output, args.target)
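convert_lora.py renames keys between the default LoRA layout (lora_unet_*.lora_down/lora_up plus an alpha tensor) and the Diffusers layout (diffusion_model.*.lora_A/lora_B without alpha), folding the alpha/rank scaling into the weights when converting to Diffusers and setting alpha equal to the rank when converting back. A minimal invocation sketch, not part of this commit and with placeholder file names:

```python
# Illustrative invocation sketch (not part of this commit); file names are placeholders.
import subprocess

# default layout -> Diffusers-style layout ("other"); alpha/rank scaling is folded into the weights
subprocess.run(
    ["python", "convert_lora.py", "--input", "lora_default.safetensors",
     "--output", "lora_diffusers.safetensors", "--target", "other"],
    check=True,
)

# Diffusers-style layout -> default layout; alpha is set equal to the rank
subprocess.run(
    ["python", "convert_lora.py", "--input", "lora_diffusers.safetensors",
     "--output", "lora_default.safetensors", "--target", "default"],
    check=True,
)
```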
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ @dataclass
28
+ class BaseDatasetParams:
29
+ resolution: Tuple[int, int] = (960, 544)
30
+ enable_bucket: bool = False
31
+ bucket_no_upscale: bool = False
32
+ caption_extension: Optional[str] = None
33
+ batch_size: int = 1
34
+ cache_directory: Optional[str] = None
35
+ debug_dataset: bool = False
36
+
37
+
38
+ @dataclass
39
+ class ImageDatasetParams(BaseDatasetParams):
40
+ image_directory: Optional[str] = None
41
+ image_jsonl_file: Optional[str] = None
42
+
43
+
44
+ @dataclass
45
+ class VideoDatasetParams(BaseDatasetParams):
46
+ video_directory: Optional[str] = None
47
+ video_jsonl_file: Optional[str] = None
48
+ target_frames: Sequence[int] = (1,)
49
+ frame_extraction: Optional[str] = "head"
50
+ frame_stride: Optional[int] = 1
51
+ frame_sample: Optional[int] = 1
52
+
53
+
54
+ @dataclass
55
+ class DatasetBlueprint:
56
+ is_image_dataset: bool
57
+ params: Union[ImageDatasetParams, VideoDatasetParams]
58
+
59
+
60
+ @dataclass
61
+ class DatasetGroupBlueprint:
62
+ datasets: Sequence[DatasetBlueprint]
63
+
64
+
65
+ @dataclass
66
+ class Blueprint:
67
+ dataset_group: DatasetGroupBlueprint
68
+
69
+
70
+ class ConfigSanitizer:
71
+ # @curry
72
+ @staticmethod
73
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
74
+ Schema(ExactSequence([klass, klass]))(value)
75
+ return tuple(value)
76
+
77
+ # @curry
78
+ @staticmethod
79
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
80
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
81
+ try:
82
+ Schema(klass)(value)
83
+ return (value, value)
84
+ except:
85
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
86
+
87
+ # datasets schema
88
+ DATASET_ASCENDABLE_SCHEMA = {
89
+ "caption_extension": str,
90
+ "batch_size": int,
91
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
92
+ "enable_bucket": bool,
93
+ "bucket_no_upscale": bool,
94
+ }
95
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
96
+ "image_directory": str,
97
+ "image_jsonl_file": str,
98
+ "cache_directory": str,
99
+ }
100
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
101
+ "video_directory": str,
102
+ "video_jsonl_file": str,
103
+ "target_frames": [int],
104
+ "frame_extraction": str,
105
+ "frame_stride": int,
106
+ "frame_sample": int,
107
+ "cache_directory": str,
108
+ }
109
+
110
+ # options handled by argparse but not handled by user config
111
+ ARGPARSE_SPECIFIC_SCHEMA = {
112
+ "debug_dataset": bool,
113
+ }
114
+
115
+ def __init__(self) -> None:
116
+ self.image_dataset_schema = self.__merge_dict(
117
+ self.DATASET_ASCENDABLE_SCHEMA,
118
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
119
+ )
120
+ self.video_dataset_schema = self.__merge_dict(
121
+ self.DATASET_ASCENDABLE_SCHEMA,
122
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
123
+ )
124
+
125
+ def validate_flex_dataset(dataset_config: dict):
126
+ if "target_frames" in dataset_config:
127
+ return Schema(self.video_dataset_schema)(dataset_config)
128
+ else:
129
+ return Schema(self.image_dataset_schema)(dataset_config)
130
+
131
+ self.dataset_schema = validate_flex_dataset
132
+
133
+ self.general_schema = self.__merge_dict(
134
+ self.DATASET_ASCENDABLE_SCHEMA,
135
+ )
136
+ self.user_config_validator = Schema(
137
+ {
138
+ "general": self.general_schema,
139
+ "datasets": [self.dataset_schema],
140
+ }
141
+ )
142
+ self.argparse_schema = self.__merge_dict(
143
+ self.ARGPARSE_SPECIFIC_SCHEMA,
144
+ )
145
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
146
+
147
+ def sanitize_user_config(self, user_config: dict) -> dict:
148
+ try:
149
+ return self.user_config_validator(user_config)
150
+ except MultipleInvalid:
151
+ # TODO: clarify the error message
152
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
153
+ raise
154
+
155
+ # NOTE: Strictly speaking, the argument parser result does not need to be sanitized.
156
+ # However, this helps us detect program bugs.
157
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
158
+ try:
159
+ return self.argparse_config_validator(argparse_namespace)
160
+ except MultipleInvalid:
161
+ # XXX: this should be a bug
162
+ logger.error(
163
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
164
+ )
165
+ raise
166
+
167
+ # NOTE: values are overwritten by later dicts if the same key already exists
168
+ @staticmethod
169
+ def __merge_dict(*dict_list: dict) -> dict:
170
+ merged = {}
171
+ for schema in dict_list:
172
+ # merged |= schema
173
+ for k, v in schema.items():
174
+ merged[k] = v
175
+ return merged
176
+
177
+
178
+ class BlueprintGenerator:
179
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
180
+
181
+ def __init__(self, sanitizer: ConfigSanitizer):
182
+ self.sanitizer = sanitizer
183
+
184
+ # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
185
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
186
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
187
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
188
+
189
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
190
+ general_config = sanitized_user_config.get("general", {})
191
+
192
+ dataset_blueprints = []
193
+ for dataset_config in sanitized_user_config.get("datasets", []):
194
+ is_image_dataset = "target_frames" not in dataset_config
195
+ if is_image_dataset:
196
+ dataset_params_klass = ImageDatasetParams
197
+ else:
198
+ dataset_params_klass = VideoDatasetParams
199
+
200
+ params = self.generate_params_by_fallbacks(
201
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
202
+ )
203
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
204
+
205
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
206
+
207
+ return Blueprint(dataset_group_blueprint)
208
+
209
+ @staticmethod
210
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
211
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
212
+ search_value = BlueprintGenerator.search_value
213
+ default_params = asdict(param_klass())
214
+ param_names = default_params.keys()
215
+
216
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
217
+
218
+ return param_klass(**params)
219
+
220
+ @staticmethod
221
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
222
+ for cand in fallbacks:
223
+ value = cand.get(key)
224
+ if value is not None:
225
+ return value
226
+
227
+ return default_value
228
+
229
+
230
+ # if training is True, it will return a dataset group for training, otherwise for caching
231
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
232
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
233
+
234
+ for dataset_blueprint in dataset_group_blueprint.datasets:
235
+ if dataset_blueprint.is_image_dataset:
236
+ dataset_klass = ImageDataset
237
+ else:
238
+ dataset_klass = VideoDataset
239
+
240
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
241
+ datasets.append(dataset)
242
+
243
+ # print info
244
+ info = ""
245
+ for i, dataset in enumerate(datasets):
246
+ is_image_dataset = isinstance(dataset, ImageDataset)
247
+ info += dedent(
248
+ f"""\
249
+ [Dataset {i}]
250
+ is_image_dataset: {is_image_dataset}
251
+ resolution: {dataset.resolution}
252
+ batch_size: {dataset.batch_size}
253
+ caption_extension: "{dataset.caption_extension}"
254
+ enable_bucket: {dataset.enable_bucket}
255
+ bucket_no_upscale: {dataset.bucket_no_upscale}
256
+ cache_directory: "{dataset.cache_directory}"
257
+ debug_dataset: {dataset.debug_dataset}
258
+ """
259
+ )
260
+
261
+ if is_image_dataset:
262
+ info += indent(
263
+ dedent(
264
+ f"""\
265
+ image_directory: "{dataset.image_directory}"
266
+ image_jsonl_file: "{dataset.image_jsonl_file}"
267
+ \n"""
268
+ ),
269
+ " ",
270
+ )
271
+ else:
272
+ info += indent(
273
+ dedent(
274
+ f"""\
275
+ video_directory: "{dataset.video_directory}"
276
+ video_jsonl_file: "{dataset.video_jsonl_file}"
277
+ target_frames: {dataset.target_frames}
278
+ frame_extraction: {dataset.frame_extraction}
279
+ frame_stride: {dataset.frame_stride}
280
+ frame_sample: {dataset.frame_sample}
281
+ \n"""
282
+ ),
283
+ " ",
284
+ )
285
+ logger.info(f"{info}")
286
+
287
+ # make buckets first because it determines the length of dataset
288
+ # and set the same seed for all datasets
289
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
290
+ for i, dataset in enumerate(datasets):
291
+ # logger.info(f"[Dataset {i}]")
292
+ dataset.set_seed(seed)
293
+ if training:
294
+ dataset.prepare_for_training()
295
+
296
+ return DatasetGroup(datasets)
297
+
298
+
299
+ def load_user_config(file: str) -> dict:
300
+ file: Path = Path(file)
301
+ if not file.is_file():
302
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
303
+
304
+ if file.name.lower().endswith(".json"):
305
+ try:
306
+ with open(file, "r") as f:
307
+ config = json.load(f)
308
+ except Exception:
309
+ logger.error(
310
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
311
+ )
312
+ raise
313
+ elif file.name.lower().endswith(".toml"):
314
+ try:
315
+ config = toml.load(file)
316
+ except Exception:
317
+ logger.error(
318
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
319
+ )
320
+ raise
321
+ else:
322
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
323
+
324
+ return config
325
+
326
+
327
+ # for config test
328
+ if __name__ == "__main__":
329
+ parser = argparse.ArgumentParser()
330
+ parser.add_argument("dataset_config")
331
+ config_args, remain = parser.parse_known_args()
332
+
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--debug_dataset", action="store_true")
335
+ argparse_namespace = parser.parse_args(remain)
336
+
337
+ logger.info("[argparse_namespace]")
338
+ logger.info(f"{vars(argparse_namespace)}")
339
+
340
+ user_config = load_user_config(config_args.dataset_config)
341
+
342
+ logger.info("")
343
+ logger.info("[user_config]")
344
+ logger.info(f"{user_config}")
345
+
346
+ sanitizer = ConfigSanitizer()
347
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
348
+
349
+ logger.info("")
350
+ logger.info("[sanitized_user_config]")
351
+ logger.info(f"{sanitized_user_config}")
352
+
353
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
354
+
355
+ logger.info("")
356
+ logger.info("[blueprint]")
357
+ logger.info(f"{blueprint}")
358
+
359
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
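This module is used by the caching and training scripts in this commit roughly as follows. The sketch is minimal and mirrors the calls made in cache_latents.py rather than introducing a new API; the TOML path is a placeholder.

```python
# Minimal usage sketch (mirrors cache_latents.py); "path/to/dataset.toml" is a placeholder.
import argparse

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer

args = argparse.Namespace(debug_dataset=False)  # argparse side, validated by ARGPARSE_SPECIFIC_SCHEMA

user_config = config_utils.load_user_config("path/to/dataset.toml")
blueprint = BlueprintGenerator(ConfigSanitizer()).generate(user_config, args)
dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

for i, dataset in enumerate(dataset_group.datasets):
    print(f"[Dataset {i}] resolution={dataset.resolution}, batch_size={dataset.batch_size}")
```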
dataset/dataset_config.md ADDED
@@ -0,0 +1,293 @@
1
+ ## Dataset Configuration
2
+
3
+ Please create a TOML file for dataset configuration.
4
+
5
+ Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
6
+
7
+ ### Sample for Image Dataset with Caption Text Files
8
+
9
+ ```toml
10
+ # resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
11
+
12
+ # general configurations
13
+ [general]
14
+ resolution = [960, 544]
15
+ caption_extension = ".txt"
16
+ batch_size = 1
17
+ enable_bucket = true
18
+ bucket_no_upscale = false
19
+
20
+ [[datasets]]
21
+ image_directory = "/path/to/image_dir"
22
+
23
+ # other datasets can be added here. each dataset can have different configurations
24
+ ```
25
+
26
+ ### Sample for Image Dataset with Metadata JSONL File
27
+
28
+ ```toml
29
+ # resolution, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
30
+ # caption_extension is not required for metadata jsonl file
31
+ # cache_directory is required for each dataset with metadata jsonl file
32
+
33
+ # general configurations
34
+ [general]
35
+ resolution = [960, 544]
36
+ batch_size = 1
37
+ enable_bucket = true
38
+ bucket_no_upscale = false
39
+
40
+ [[datasets]]
41
+ image_jsonl_file = "/path/to/metadata.jsonl"
42
+ cache_directory = "/path/to/cache_directory"
43
+
44
+ # other datasets can be added here. each dataset can have different configurations
45
+ ```
46
+
47
+ JSONL file format for metadata:
48
+
49
+ ```json
50
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
51
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
52
+ ```
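A metadata JSONL file like the one above can be generated from an existing folder of images with same-named caption .txt files. The helper below is an illustrative sketch, not part of this commit; the directory layout and the .jpg filter are assumptions.

```python
# Illustrative helper (not part of this commit): build metadata.jsonl from a folder of
# images with same-named .txt caption files. Paths and the .jpg filter are placeholders.
import glob
import json
import os

with open("metadata.jsonl", "w", encoding="utf-8") as out:
    for image_path in sorted(glob.glob("/path/to/image_dir/*.jpg")):
        caption_path = os.path.splitext(image_path)[0] + ".txt"
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        out.write(json.dumps({"image_path": image_path, "caption": caption}) + "\n")
```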
53
+
54
+ ### Sample for Video Dataset with Caption Text Files
55
+
56
+ ```toml
57
+ # resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
58
+
59
+ # general configurations
60
+ [general]
61
+ resolution = [960, 544]
62
+ caption_extension = ".txt"
63
+ batch_size = 1
64
+ enable_bucket = true
65
+ bucket_no_upscale = false
66
+
67
+ [[datasets]]
68
+ video_directory = "/path/to/video_dir"
69
+ target_frames = [1, 25, 45]
70
+ frame_extraction = "head"
71
+
72
+ # other datasets can be added here. each dataset can have different configurations
73
+ ```
74
+
75
+ ### Sample for Video Dataset with Metadata JSONL File
76
+
77
+ ```toml
78
+ # resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, enable_bucket, bucket_no_upscale must be set in either general or datasets
79
+ # caption_extension is not required for metadata jsonl file
80
+ # cache_directory is required for each dataset with metadata jsonl file
81
+
82
+ # general configurations
83
+ [general]
84
+ resolution = [960, 544]
85
+ batch_size = 1
86
+ enable_bucket = true
87
+ bucket_no_upscale = false
88
+
89
+ [[datasets]]
90
+ video_jsonl_file = "/path/to/metadata.jsonl"
91
+ target_frames = [1, 25, 45]
92
+ frame_extraction = "head"
93
+ cache_directory = "/path/to/cache_directory"
94
+
95
+ # same metadata jsonl file can be used for multiple datasets
96
+ [[datasets]]
97
+ video_jsonl_file = "/path/to/metadata.jsonl"
98
+ target_frames = [1]
99
+ frame_stride = 10
100
+ cache_directory = "/path/to/cache_directory"
101
+
102
+ # other datasets can be added here. each dataset can have different configurations
103
+ ```
104
+
105
+ JSONL file format for metadata:
106
+
107
+ ```json
108
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
109
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
110
+ ```
111
+
112
+ ### frame_extraction Options
113
+
114
+ - `head`: Extract the first N frames from the video.
115
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
116
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
117
+ - `uniform`: Extract `frame_sample` samples uniformly from the video.
118
+
119
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
120
+
121
+ ```
122
+ Original Video, 40 frames: x = frame, o = no frame
123
+ oooooooooooooooooooooooooooooooooooooooo
124
+
125
+ head, target_frames = [1, 13, 25] -> extract head frames:
126
+ xooooooooooooooooooooooooooooooooooooooo
127
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
128
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
129
+
130
+ chunk, target_frames = [13, 25] -> extract frames by splitting the video into chunks of 13 and 25 frames:
131
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
132
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
133
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
134
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
135
+
136
+ NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will cause all frames to be extracted.
137
+
138
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
139
+ xooooooooooooooooooooooooooooooooooooooo
140
+ ooooooooooxooooooooooooooooooooooooooooo
141
+ ooooooooooooooooooooxooooooooooooooooooo
142
+ ooooooooooooooooooooooooooooooxooooooooo
143
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
144
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
145
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
146
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
147
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
148
+
149
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
150
+ xooooooooooooooooooooooooooooooooooooooo
151
+ oooooooooooooxoooooooooooooooooooooooooo
152
+ oooooooooooooooooooooooooxoooooooooooooo
153
+ ooooooooooooooooooooooooooooooooooooooox
154
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
155
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
156
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
157
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
158
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
159
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
160
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
161
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
162
+ ```
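The sketch below reproduces the window selection illustrated above in code form. It is an illustration of the documented behavior only, not the implementation in dataset/image_video_dataset.py, and the helper name is hypothetical.

```python
# Illustrative sketch of the documented frame_extraction modes (hypothetical helper,
# not the repository's implementation). Returns (start, length) windows per target length.
def extraction_windows(video_frames, target_frames, mode="head", frame_stride=1, frame_sample=1):
    windows = []
    for n in target_frames:
        if mode == "head":
            windows.append((0, n))  # first N frames only
        elif mode == "chunk":
            for start in range(0, video_frames - n + 1, n):  # non-overlapping chunks of N frames
                windows.append((start, n))
        elif mode == "slide":
            for start in range(0, video_frames - n + 1, frame_stride):  # sliding window with frame_stride
                windows.append((start, n))
        elif mode == "uniform":
            for i in range(frame_sample):  # frame_sample windows spread uniformly over the video
                start = round(i * (video_frames - n) / max(1, frame_sample - 1))
                windows.append((start, n))
    return windows

# e.g. extraction_windows(40, [13, 25], mode="chunk") -> [(0, 13), (13, 13), (26, 13), (0, 25)]
```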
163
+
164
+ ## Specifications
165
+
166
+ ```toml
167
+ # general configurations
168
+ [general]
169
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
170
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
171
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
172
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
173
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
174
+
175
+ ### Image Dataset
176
+
177
+ # sample image dataset with caption text files
178
+ [[datasets]]
179
+ image_directory = "/path/to/image_dir"
180
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
181
+ resolution = [960, 544] # required if general resolution is not set
182
+ batch_size = 4 # optional, overwrite the default batch size
183
+ enable_bucket = false # optional, overwrite the default bucketing setting
184
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
185
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
186
+
187
+ # sample image dataset with metadata **jsonl** file
188
+ [[datasets]]
189
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
190
+ resolution = [960, 544] # required if general resolution is not set
191
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
192
+ # caption_extension is not required for metadata jsonl file
193
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
194
+
195
+ ### Video Dataset
196
+
197
+ # sample video dataset with caption text files
198
+ [[datasets]]
199
+ video_directory = "/path/to/video_dir"
200
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
201
+ resolution = [960, 544] # required if general resolution is not set
202
+
203
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
204
+
205
+ # NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will cause all frames to be extracted.
206
+
207
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
208
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
209
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
210
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
211
+
212
+ # sample video dataset with metadata jsonl file
213
+ [[datasets]]
214
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
215
+
216
+ target_frames = [1, 79]
217
+
218
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
219
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
220
+ ```
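As the specification above notes, each target_frames value must have the form N*4+1 (1, 5, 9, ...). A quick validity check, sketched here for illustration only and not part of this commit:

```python
# Illustrative check (not part of this commit): valid target_frames values are 4*N + 1.
def is_valid_target_frame_count(n: int) -> bool:
    return n >= 1 and (n - 1) % 4 == 0

assert all(is_valid_target_frame_count(n) for n in [1, 25, 45])
assert not is_valid_target_frame_count(24)
```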
221
+
222
+ <!--
223
+ # sample image dataset with lance
224
+ [[datasets]]
225
+ image_lance_dataset = "/path/to/lance_dataset"
226
+ resolution = [960, 544] # required if general resolution is not set
227
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
228
+ -->
229
+
230
+ Metadata in .json format will be supported in the near future.
231
+
232
+
233
+
234
+ <!--
235
+
236
+ ```toml
237
+ # general configurations
238
+ [general]
239
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
240
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
241
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
242
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
243
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
244
+
245
+ # sample image dataset with caption text files
246
+ [[datasets]]
247
+ image_directory = "/path/to/image_dir"
248
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
249
+ resolution = [960, 544] # required if general resolution is not set
250
+ batch_size = 4 # optional, overwrite the default batch size
251
+ enable_bucket = false # optional, overwrite the default bucketing setting
252
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
253
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
254
+
255
+ # sample image dataset with metadata **jsonl** file
256
+ [[datasets]]
257
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
258
+ resolution = [960, 544] # required if general resolution is not set
259
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
260
+ # caption_extension is not required for metadata jsonl file
261
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
262
+
263
+ # sample video dataset with caption text files
264
+ [[datasets]]
265
+ video_directory = "/path/to/video_dir"
266
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
267
+ resolution = [960, 544] # required if general resolution is not set
268
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
269
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
270
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
271
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
272
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
273
+
274
+ # sample video dataset with metadata jsonl file
275
+ [[datasets]]
276
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
277
+ target_frames = [1, 79]
278
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
279
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
280
+ ```
281
+
282
+ # sample image dataset with lance
283
+ [[datasets]]
284
+ image_lance_dataset = "/path/to/lance_dataset"
285
+ resolution = [960, 544] # required if general resolution is not set
286
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
287
+
288
+ The metadata with .json file will be supported in the near future.
289
+
290
+
291
+
292
+
293
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1255 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
28
+
29
+ try:
30
+ import pillow_avif
31
+
32
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
33
+ except:
34
+ pass
35
+
36
+ # JPEG-XL on Linux
37
+ try:
38
+ from jxlpy import JXLImagePlugin
39
+
40
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
41
+ except:
42
+ pass
43
+
44
+ # JPEG-XL on Windows
45
+ try:
46
+ import pillow_jxl
47
+
48
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
49
+ except:
50
+ pass
51
+
52
+ VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"] # some of them are not tested
53
+
54
+ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
55
+
56
+
57
+ def glob_images(directory, base="*"):
58
+ img_paths = []
59
+ for ext in IMAGE_EXTENSIONS:
60
+ if base == "*":
61
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
62
+ else:
63
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
64
+ img_paths = list(set(img_paths)) # remove duplicates
65
+ img_paths.sort()
66
+ return img_paths
67
+
68
+
69
+ def glob_videos(directory, base="*"):
70
+ video_paths = []
71
+ for ext in VIDEO_EXTENSIONS:
72
+ if base == "*":
73
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
74
+ else:
75
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
76
+ video_paths = list(set(video_paths)) # remove duplicates
77
+ video_paths.sort()
78
+ return video_paths
79
+
80
+
81
+ def divisible_by(num: int, divisor: int) -> int:
82
+ return num - num % divisor
83
+
84
+
85
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
86
+ """
87
+ Resize the image to the bucket resolution.
88
+ """
89
+ is_pil_image = isinstance(image, Image.Image)
90
+ if is_pil_image:
91
+ image_width, image_height = image.size
92
+ else:
93
+ image_height, image_width = image.shape[:2]
94
+
95
+ if bucket_reso == (image_width, image_height):
96
+ return np.array(image) if is_pil_image else image
97
+
98
+ bucket_width, bucket_height = bucket_reso
99
+ if bucket_width == image_width or bucket_height == image_height:
100
+ image = np.array(image) if is_pil_image else image
101
+ else:
102
+ # resize the image to the bucket resolution to match the short side
103
+ scale_width = bucket_width / image_width
104
+ scale_height = bucket_height / image_height
105
+ scale = max(scale_width, scale_height)
106
+ image_width = int(image_width * scale + 0.5)
107
+ image_height = int(image_height * scale + 0.5)
108
+
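+ # upscaling uses PIL LANCZOS; downscaling uses cv2 INTER_AREA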
109
+ if scale > 1:
110
+ image = Image.fromarray(image) if not is_pil_image else image
111
+ image = image.resize((image_width, image_height), Image.LANCZOS)
112
+ image = np.array(image)
113
+ else:
114
+ image = np.array(image) if is_pil_image else image
115
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
116
+
117
+ # crop the image to the bucket resolution
118
+ crop_left = (image_width - bucket_width) // 2
119
+ crop_top = (image_height - bucket_height) // 2
120
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
121
+ return image
122
+
123
+
124
+ class ItemInfo:
125
+ def __init__(
126
+ self,
127
+ item_key: str,
128
+ caption: str,
129
+ original_size: tuple[int, int],
130
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
131
+ frame_count: Optional[int] = None,
132
+ content: Optional[np.ndarray] = None,
133
+ latent_cache_path: Optional[str] = None,
134
+ ) -> None:
135
+ self.item_key = item_key
136
+ self.caption = caption
137
+ self.original_size = original_size
138
+ self.bucket_size = bucket_size
139
+ self.frame_count = frame_count
140
+ self.content = content
141
+ self.latent_cache_path = latent_cache_path
142
+ self.text_encoder_output_cache_path: Optional[str] = None
143
+
144
+ def __str__(self) -> str:
145
+ return (
146
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
147
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
148
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
149
+ )
150
+
151
+
152
+ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
153
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
154
+ metadata = {
155
+ "architecture": "hunyuan_video",
156
+ "width": f"{item_info.original_size[0]}",
157
+ "height": f"{item_info.original_size[1]}",
158
+ "format_version": "1.0.0",
159
+ }
160
+ if item_info.frame_count is not None:
161
+ metadata["frame_count"] = f"{item_info.frame_count}"
162
+
163
+ _, F, H, W = latent.shape
164
+ dtype_str = dtype_to_str(latent.dtype)
165
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
166
+
167
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
168
+ os.makedirs(latent_dir, exist_ok=True)
169
+
170
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
171
+
172
+
173
+ def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
174
+ assert (
175
+ embed.dim() == 1 or embed.dim() == 2
176
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
177
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
178
+ metadata = {
179
+ "architecture": "hunyuan_video",
180
+ "caption1": item_info.caption,
181
+ "format_version": "1.0.0",
182
+ }
183
+
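+ # LLM and CLIP-L outputs share a single cache file per item; existing tensors are loaded and merged before saving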
184
+ sd = {}
185
+ if os.path.exists(item_info.text_encoder_output_cache_path):
186
+ # load existing cache and update metadata
187
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
188
+ existing_metadata = f.metadata()
189
+ for key in f.keys():
190
+ sd[key] = f.get_tensor(key)
191
+
192
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
193
+ if existing_metadata["caption1"] != metadata["caption1"]:
194
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
195
+ # TODO verify format_version
196
+
197
+ existing_metadata.pop("caption1", None)
198
+ existing_metadata.pop("format_version", None)
199
+ metadata.update(existing_metadata) # copy existing metadata
200
+ else:
201
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
202
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
203
+
204
+ dtype_str = dtype_to_str(embed.dtype)
205
+ text_encoder_type = "llm" if is_llm else "clipL"
206
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
207
+ if mask is not None:
208
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
209
+
210
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
211
+
212
+
213
+ class BucketSelector:
214
+ RESOLUTION_STEPS_HUNYUAN = 16
215
+
216
+ def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
217
+ self.resolution = resolution
218
+ self.bucket_area = resolution[0] * resolution[1]
219
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
220
+
221
+ if not enable_bucket:
222
+ # only define one bucket
223
+ self.bucket_resolutions = [resolution]
224
+ self.no_upscale = False
225
+ else:
226
+ # prepare bucket resolution
227
+ self.no_upscale = no_upscale
228
+ sqrt_size = int(math.sqrt(self.bucket_area))
229
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
230
+ self.bucket_resolutions = []
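+ # enumerate widths in reso_steps increments up to sqrt(bucket_area), pair each with the largest height that keeps the area within bucket_area, and register both orientations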
231
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
232
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
233
+ self.bucket_resolutions.append((w, h))
234
+ self.bucket_resolutions.append((h, w))
235
+
236
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
237
+ self.bucket_resolutions.sort()
238
+
239
+ # calculate aspect ratio to find the nearest resolution
240
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
241
+
242
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
243
+ """
244
+ return the bucket resolution for the given image size, (width, height)
245
+ """
246
+ area = image_size[0] * image_size[1]
247
+ if self.no_upscale and area <= self.bucket_area:
248
+ w, h = image_size
249
+ w = divisible_by(w, self.reso_steps)
250
+ h = divisible_by(h, self.reso_steps)
251
+ return w, h
252
+
253
+ aspect_ratio = image_size[0] / image_size[1]
254
+ ar_errors = self.aspect_ratios - aspect_ratio
255
+ bucket_id = np.abs(ar_errors).argmin()
256
+ return self.bucket_resolutions[bucket_id]
257
+
258
+
259
+ def load_video(
260
+ video_path: str,
261
+ start_frame: Optional[int] = None,
262
+ end_frame: Optional[int] = None,
263
+ bucket_selector: Optional[BucketSelector] = None,
264
+ ) -> list[np.ndarray]:
265
+ container = av.open(video_path)
266
+ video = []
267
+ bucket_reso = None
268
+ for i, frame in enumerate(container.decode(video=0)):
269
+ if start_frame is not None and i < start_frame:
270
+ continue
271
+ if end_frame is not None and i >= end_frame:
272
+ break
273
+ frame = frame.to_image()
274
+
275
+ if bucket_selector is not None and bucket_reso is None:
276
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
277
+
278
+ if bucket_reso is not None:
279
+ frame = resize_image_to_bucket(frame, bucket_reso)
280
+ else:
281
+ frame = np.array(frame)
282
+
283
+ video.append(frame)
284
+ container.close()
285
+ return video
286
+
287
+
288
+ class BucketBatchManager:
289
+
290
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
291
+ self.batch_size = batch_size
292
+ self.buckets = bucketed_item_info
293
+ self.bucket_resos = list(self.buckets.keys())
294
+ self.bucket_resos.sort()
295
+
296
+ self.bucket_batch_indices = []
297
+ for bucket_reso in self.bucket_resos:
298
+ bucket = self.buckets[bucket_reso]
299
+ num_batches = math.ceil(len(bucket) / self.batch_size)
300
+ for i in range(num_batches):
301
+ self.bucket_batch_indices.append((bucket_reso, i))
302
+
303
+ self.shuffle()
304
+
305
+ def show_bucket_info(self):
306
+ for bucket_reso in self.bucket_resos:
307
+ bucket = self.buckets[bucket_reso]
308
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
309
+
310
+ logger.info(f"total batches: {len(self)}")
311
+
312
+ def shuffle(self):
313
+ for bucket in self.buckets.values():
314
+ random.shuffle(bucket)
315
+ random.shuffle(self.bucket_batch_indices)
316
+
317
+ def __len__(self):
318
+ return len(self.bucket_batch_indices)
319
+
320
+ def __getitem__(self, idx):
321
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
322
+ bucket = self.buckets[bucket_reso]
323
+ start = batch_idx * self.batch_size
324
+ end = min(start + self.batch_size, len(bucket))
325
+
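+ # load the cached latent and text encoder outputs for each item; tensor keys are prefixed with latents_, llm_, llm_mask and clipL_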
326
+ latents = []
327
+ llm_embeds = []
328
+ llm_masks = []
329
+ clip_l_embeds = []
330
+ for item_info in bucket[start:end]:
331
+ sd = load_file(item_info.latent_cache_path)
332
+ latent = None
333
+ for key in sd.keys():
334
+ if key.startswith("latents_"):
335
+ latent = sd[key]
336
+ break
337
+ latents.append(latent)
338
+
339
+ sd = load_file(item_info.text_encoder_output_cache_path)
340
+ llm_embed = llm_mask = clip_l_embed = None
341
+ for key in sd.keys():
342
+ if key.startswith("llm_mask"):
343
+ llm_mask = sd[key]
344
+ elif key.startswith("llm_"):
345
+ llm_embed = sd[key]
346
+ elif key.startswith("clipL_mask"):
347
+ pass
348
+ elif key.startswith("clipL_"):
349
+ clip_l_embed = sd[key]
350
+ llm_embeds.append(llm_embed)
351
+ llm_masks.append(llm_mask)
352
+ clip_l_embeds.append(clip_l_embed)
353
+
354
+ latents = torch.stack(latents)
355
+ llm_embeds = torch.stack(llm_embeds)
356
+ llm_masks = torch.stack(llm_masks)
357
+ clip_l_embeds = torch.stack(clip_l_embeds)
358
+
359
+ return latents, llm_embeds, llm_masks, clip_l_embeds
360
+
361
+
362
+ class ContentDatasource:
363
+ def __init__(self):
364
+ self.caption_only = False
365
+
366
+ def set_caption_only(self, caption_only: bool):
367
+ self.caption_only = caption_only
368
+
369
+ def is_indexable(self):
370
+ return False
371
+
372
+ def get_caption(self, idx: int) -> tuple[str, str]:
373
+ """
374
+ Returns caption. May not be called if is_indexable() returns False.
375
+ """
376
+ raise NotImplementedError
377
+
378
+ def __len__(self):
379
+ raise NotImplementedError
380
+
381
+ def __iter__(self):
382
+ raise NotImplementedError
383
+
384
+ def __next__(self):
385
+ raise NotImplementedError
386
+
387
+
388
+ class ImageDatasource(ContentDatasource):
389
+ def __init__(self):
390
+ super().__init__()
391
+
392
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
393
+ """
394
+ Returns image data as a tuple of image path, image, and caption for the given index.
395
+ Key must be unique and valid as a file name.
396
+ May not be called if is_indexable() returns False.
397
+ """
398
+ raise NotImplementedError
399
+
400
+
401
+ class ImageDirectoryDatasource(ImageDatasource):
402
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
403
+ super().__init__()
404
+ self.image_directory = image_directory
405
+ self.caption_extension = caption_extension
406
+ self.current_idx = 0
407
+
408
+ # glob images
409
+ logger.info(f"glob images in {self.image_directory}")
410
+ self.image_paths = glob_images(self.image_directory)
411
+ logger.info(f"found {len(self.image_paths)} images")
412
+
413
+ def is_indexable(self):
414
+ return True
415
+
416
+ def __len__(self):
417
+ return len(self.image_paths)
418
+
419
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
420
+ image_path = self.image_paths[idx]
421
+ image = Image.open(image_path).convert("RGB")
422
+
423
+ _, caption = self.get_caption(idx)
424
+
425
+ return image_path, image, caption
426
+
427
+ def get_caption(self, idx: int) -> tuple[str, str]:
428
+ image_path = self.image_paths[idx]
429
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
430
+ with open(caption_path, "r", encoding="utf-8") as f:
431
+ caption = f.read().strip()
432
+ return image_path, caption
433
+
434
+ def __iter__(self):
435
+ self.current_idx = 0
436
+ return self
437
+
438
+ def __next__(self) -> callable:
439
+ """
440
+ Returns a fetcher function that returns image data.
441
+ """
442
+ if self.current_idx >= len(self.image_paths):
443
+ raise StopIteration
444
+
445
+ if self.caption_only:
446
+
447
+ def create_caption_fetcher(index):
448
+ return lambda: self.get_caption(index)
449
+
450
+ fetcher = create_caption_fetcher(self.current_idx)
451
+ else:
452
+
453
+ def create_image_fetcher(index):
454
+ return lambda: self.get_image_data(index)
455
+
456
+ fetcher = create_image_fetcher(self.current_idx)
457
+
458
+ self.current_idx += 1
459
+ return fetcher
460
+
461
+
462
+ class ImageJsonlDatasource(ImageDatasource):
463
+ def __init__(self, image_jsonl_file: str):
464
+ super().__init__()
465
+ self.image_jsonl_file = image_jsonl_file
466
+ self.current_idx = 0
467
+
468
+ # load jsonl
469
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
470
+ self.data = []
471
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
472
+ for line in f:
473
+ data = json.loads(line)
474
+ self.data.append(data)
475
+ logger.info(f"loaded {len(self.data)} images")
476
+
477
+ def is_indexable(self):
478
+ return True
479
+
480
+ def __len__(self):
481
+ return len(self.data)
482
+
483
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
484
+ data = self.data[idx]
485
+ image_path = data["image_path"]
486
+ image = Image.open(image_path).convert("RGB")
487
+
488
+ caption = data["caption"]
489
+
490
+ return image_path, image, caption
491
+
492
+ def get_caption(self, idx: int) -> tuple[str, str]:
493
+ data = self.data[idx]
494
+ image_path = data["image_path"]
495
+ caption = data["caption"]
496
+ return image_path, caption
497
+
498
+ def __iter__(self):
499
+ self.current_idx = 0
500
+ return self
501
+
502
+ def __next__(self) -> callable:
503
+ if self.current_idx >= len(self.data):
504
+ raise StopIteration
505
+
506
+ if self.caption_only:
507
+
508
+ def create_caption_fetcher(index):
509
+ return lambda: self.get_caption(index)
510
+
511
+ fetcher = create_caption_fetcher(self.current_idx)
512
+
513
+ else:
514
+
515
+ def create_fetcher(index):
516
+ return lambda: self.get_image_data(index)
517
+
518
+ fetcher = create_fetcher(self.current_idx)
519
+
520
+ self.current_idx += 1
521
+ return fetcher
522
+
523
+
524
+ class VideoDatasource(ContentDatasource):
525
+ def __init__(self):
526
+ super().__init__()
527
+
528
+ # None means all frames
529
+ self.start_frame = None
530
+ self.end_frame = None
531
+
532
+ self.bucket_selector = None
533
+
534
+ def __len__(self):
535
+ raise NotImplementedError
536
+
537
+ def get_video_data_from_path(
538
+ self,
539
+ video_path: str,
540
+ start_frame: Optional[int] = None,
541
+ end_frame: Optional[int] = None,
542
+ bucket_selector: Optional[BucketSelector] = None,
543
+ ) -> tuple[str, list[Image.Image], str]:
544
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
545
+
546
+ start_frame = start_frame if start_frame is not None else self.start_frame
547
+ end_frame = end_frame if end_frame is not None else self.end_frame
548
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
549
+
550
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
551
+ return video
552
+
553
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
554
+ self.start_frame = start_frame
555
+ self.end_frame = end_frame
556
+
557
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
558
+ self.bucket_selector = bucket_selector
559
+
560
+ def __iter__(self):
561
+ raise NotImplementedError
562
+
563
+ def __next__(self):
564
+ raise NotImplementedError
565
+
566
+
567
+ class VideoDirectoryDatasource(VideoDatasource):
568
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
569
+ super().__init__()
570
+ self.video_directory = video_directory
571
+ self.caption_extension = caption_extension
572
+ self.current_idx = 0
573
+
574
+ # glob videos
575
+ logger.info(f"glob videos in {self.video_directory}")
576
+ self.video_paths = glob_videos(self.video_directory)
577
+ logger.info(f"found {len(self.video_paths)} videos")
578
+
579
+ def is_indexable(self):
580
+ return True
581
+
582
+ def __len__(self):
583
+ return len(self.video_paths)
584
+
585
+ def get_video_data(
586
+ self,
587
+ idx: int,
588
+ start_frame: Optional[int] = None,
589
+ end_frame: Optional[int] = None,
590
+ bucket_selector: Optional[BucketSelector] = None,
591
+ ) -> tuple[str, list[Image.Image], str]:
592
+ video_path = self.video_paths[idx]
593
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
594
+
595
+ _, caption = self.get_caption(idx)
596
+
597
+ return video_path, video, caption
598
+
599
+ def get_caption(self, idx: int) -> tuple[str, str]:
600
+ video_path = self.video_paths[idx]
601
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
602
+ with open(caption_path, "r", encoding="utf-8") as f:
603
+ caption = f.read().strip()
604
+ return video_path, caption
605
+
606
+ def __iter__(self):
607
+ self.current_idx = 0
608
+ return self
609
+
610
+ def __next__(self):
611
+ if self.current_idx >= len(self.video_paths):
612
+ raise StopIteration
613
+
614
+ if self.caption_only:
615
+
616
+ def create_caption_fetcher(index):
617
+ return lambda: self.get_caption(index)
618
+
619
+ fetcher = create_caption_fetcher(self.current_idx)
620
+
621
+ else:
622
+
623
+ def create_fetcher(index):
624
+ return lambda: self.get_video_data(index)
625
+
626
+ fetcher = create_fetcher(self.current_idx)
627
+
628
+ self.current_idx += 1
629
+ return fetcher
630
+
631
+
632
+ class VideoJsonlDatasource(VideoDatasource):
633
+ def __init__(self, video_jsonl_file: str):
634
+ super().__init__()
635
+ self.video_jsonl_file = video_jsonl_file
636
+ self.current_idx = 0
637
+
638
+ # load jsonl
639
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
640
+ self.data = []
641
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
642
+ for line in f:
643
+ data = json.loads(line)
644
+ self.data.append(data)
645
+ logger.info(f"loaded {len(self.data)} videos")
646
+
647
+ def is_indexable(self):
648
+ return True
649
+
650
+ def __len__(self):
651
+ return len(self.data)
652
+
653
+ def get_video_data(
654
+ self,
655
+ idx: int,
656
+ start_frame: Optional[int] = None,
657
+ end_frame: Optional[int] = None,
658
+ bucket_selector: Optional[BucketSelector] = None,
659
+ ) -> tuple[str, list[Image.Image], str]:
660
+ data = self.data[idx]
661
+ video_path = data["video_path"]
662
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
663
+
664
+ caption = data["caption"]
665
+
666
+ return video_path, video, caption
667
+
668
+ def get_caption(self, idx: int) -> tuple[str, str]:
669
+ data = self.data[idx]
670
+ video_path = data["video_path"]
671
+ caption = data["caption"]
672
+ return video_path, caption
673
+
674
+ def __iter__(self):
675
+ self.current_idx = 0
676
+ return self
677
+
678
+ def __next__(self):
679
+ if self.current_idx >= len(self.data):
680
+ raise StopIteration
681
+
682
+ if self.caption_only:
683
+
684
+ def create_caption_fetcher(index):
685
+ return lambda: self.get_caption(index)
686
+
687
+ fetcher = create_caption_fetcher(self.current_idx)
688
+
689
+ else:
690
+
691
+ def create_fetcher(index):
692
+ return lambda: self.get_video_data(index)
693
+
694
+ fetcher = create_fetcher(self.current_idx)
695
+
696
+ self.current_idx += 1
697
+ return fetcher
698
+
699
+
700
+ class BaseDataset(torch.utils.data.Dataset):
701
+ def __init__(
702
+ self,
703
+ resolution: Tuple[int, int] = (960, 544),
704
+ caption_extension: Optional[str] = None,
705
+ batch_size: int = 1,
706
+ enable_bucket: bool = False,
707
+ bucket_no_upscale: bool = False,
708
+ cache_directory: Optional[str] = None,
709
+ debug_dataset: bool = False,
710
+ ):
711
+ self.resolution = resolution
712
+ self.caption_extension = caption_extension
713
+ self.batch_size = batch_size
714
+ self.enable_bucket = enable_bucket
715
+ self.bucket_no_upscale = bucket_no_upscale
716
+ self.cache_directory = cache_directory
717
+ self.debug_dataset = debug_dataset
718
+ self.seed = None
719
+ self.current_epoch = 0
720
+
721
+ if not self.enable_bucket:
722
+ self.bucket_no_upscale = False
723
+
724
+ def get_metadata(self) -> dict:
725
+ metadata = {
726
+ "resolution": self.resolution,
727
+ "caption_extension": self.caption_extension,
728
+ "batch_size_per_device": self.batch_size,
729
+ "enable_bucket": bool(self.enable_bucket),
730
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
731
+ }
732
+ return metadata
733
+
734
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
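+ # cache file name format: {basename}_{WWWWxHHHH}_hv.safetensors; prepare_for_training parses this name back into the item key and size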
735
+ w, h = item_info.original_size
736
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
737
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
738
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")
739
+
740
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
741
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
742
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
743
+ return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")
744
+
745
+ def retrieve_latent_cache_batches(self, num_workers: int):
746
+ raise NotImplementedError
747
+
748
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
749
+ raise NotImplementedError
750
+
751
+ def prepare_for_training(self):
752
+ pass
753
+
754
+ def set_seed(self, seed: int):
755
+ self.seed = seed
756
+
757
+ def set_current_epoch(self, epoch):
758
+ if self.current_epoch != epoch: # shuffle buckets when epoch is incremented
759
+ if epoch > self.current_epoch:
760
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
761
+ num_epochs = epoch - self.current_epoch
762
+ for _ in range(num_epochs):
763
+ self.current_epoch += 1
764
+ self.shuffle_buckets()
765
+ # self.current_epoch seems to be reset to 0 in the next epoch; it may be caused by skipped_dataloader?
766
+ else:
767
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
768
+ self.current_epoch = epoch
769
+
770
+ def set_current_step(self, step):
771
+ self.current_step = step
772
+
773
+ def set_max_train_steps(self, max_train_steps):
774
+ self.max_train_steps = max_train_steps
775
+
776
+ def shuffle_buckets(self):
777
+ raise NotImplementedError
778
+
779
+ def __len__(self):
780
+ raise NotImplementedError
781
+
782
+ def __getitem__(self, idx):
783
+ raise NotImplementedError
784
+
785
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
786
+ datasource.set_caption_only(True)
787
+ executor = ThreadPoolExecutor(max_workers=num_workers)
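+ # captions are fetched concurrently by worker threads and yielded as fixed-size batches of ItemInfo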
788
+
789
+ data: list[ItemInfo] = []
790
+ futures = []
791
+
792
+ def aggregate_future(consume_all: bool = False):
793
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
794
+ completed_futures = [future for future in futures if future.done()]
795
+ if len(completed_futures) == 0:
796
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
797
+ time.sleep(0.1)
798
+ continue
799
+ else:
800
+ break # submit batch if possible
801
+
802
+ for future in completed_futures:
803
+ item_key, caption = future.result()
804
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
805
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
806
+ data.append(item_info)
807
+
808
+ futures.remove(future)
809
+
810
+ def submit_batch(flush: bool = False):
811
+ nonlocal data
812
+ if len(data) >= batch_size or (len(data) > 0 and flush):
813
+ batch = data[0:batch_size]
814
+ if len(data) > batch_size:
815
+ data = data[batch_size:]
816
+ else:
817
+ data = []
818
+ return batch
819
+ return None
820
+
821
+ for fetch_op in datasource:
822
+ future = executor.submit(fetch_op)
823
+ futures.append(future)
824
+ aggregate_future()
825
+ while True:
826
+ batch = submit_batch()
827
+ if batch is None:
828
+ break
829
+ yield batch
830
+
831
+ aggregate_future(consume_all=True)
832
+ while True:
833
+ batch = submit_batch(flush=True)
834
+ if batch is None:
835
+ break
836
+ yield batch
837
+
838
+ executor.shutdown()
839
+
840
+
841
+ class ImageDataset(BaseDataset):
842
+ def __init__(
843
+ self,
844
+ resolution: Tuple[int, int],
845
+ caption_extension: Optional[str],
846
+ batch_size: int,
847
+ enable_bucket: bool,
848
+ bucket_no_upscale: bool,
849
+ image_directory: Optional[str] = None,
850
+ image_jsonl_file: Optional[str] = None,
851
+ cache_directory: Optional[str] = None,
852
+ debug_dataset: bool = False,
853
+ ):
854
+ super(ImageDataset, self).__init__(
855
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
856
+ )
857
+ self.image_directory = image_directory
858
+ self.image_jsonl_file = image_jsonl_file
859
+ if image_directory is not None:
860
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
861
+ elif image_jsonl_file is not None:
862
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
863
+ else:
864
+ raise ValueError("image_directory or image_jsonl_file must be specified")
865
+
866
+ if self.cache_directory is None:
867
+ self.cache_directory = self.image_directory
868
+
869
+ self.batch_manager = None
870
+ self.num_train_items = 0
871
+
872
+ def get_metadata(self):
873
+ metadata = super().get_metadata()
874
+ if self.image_directory is not None:
875
+ metadata["image_directory"] = os.path.basename(self.image_directory)
876
+ if self.image_jsonl_file is not None:
877
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
878
+ return metadata
879
+
880
+ def get_total_image_count(self):
881
+ return len(self.datasource) if self.datasource.is_indexable() else None
882
+
883
+ def retrieve_latent_cache_batches(self, num_workers: int):
884
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
885
+ executor = ThreadPoolExecutor(max_workers=num_workers)
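+ # images are loaded and resized in worker threads; completed items are grouped by bucket resolution and yielded once a full batch is ready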
886
+
887
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
888
+ futures = []
889
+
890
+ def aggregate_future(consume_all: bool = False):
891
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
892
+ completed_futures = [future for future in futures if future.done()]
893
+ if len(completed_futures) == 0:
894
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
895
+ time.sleep(0.1)
896
+ continue
897
+ else:
898
+ break # submit batch if possible
899
+
900
+ for future in completed_futures:
901
+ original_size, item_key, image, caption = future.result()
902
+ bucket_height, bucket_width = image.shape[:2]
903
+ bucket_reso = (bucket_width, bucket_height)
904
+
905
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
906
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
907
+
908
+ if bucket_reso not in batches:
909
+ batches[bucket_reso] = []
910
+ batches[bucket_reso].append(item_info)
911
+
912
+ futures.remove(future)
913
+
914
+ def submit_batch(flush: bool = False):
915
+ for key in batches:
916
+ if len(batches[key]) >= self.batch_size or flush:
917
+ batch = batches[key][0 : self.batch_size]
918
+ if len(batches[key]) > self.batch_size:
919
+ batches[key] = batches[key][self.batch_size :]
920
+ else:
921
+ del batches[key]
922
+ return key, batch
923
+ return None, None
924
+
925
+ for fetch_op in self.datasource:
926
+
927
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
928
+ image_key, image, caption = op()
929
+ image: Image.Image
930
+ image_size = image.size
931
+
932
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
933
+ image = resize_image_to_bucket(image, bucket_reso)
934
+ return image_size, image_key, image, caption
935
+
936
+ future = executor.submit(fetch_and_resize, fetch_op)
937
+ futures.append(future)
938
+ aggregate_future()
939
+ while True:
940
+ key, batch = submit_batch()
941
+ if key is None:
942
+ break
943
+ yield key, batch
944
+
945
+ aggregate_future(consume_all=True)
946
+ while True:
947
+ key, batch = submit_batch(flush=True)
948
+ if key is None:
949
+ break
950
+ yield key, batch
951
+
952
+ executor.shutdown()
953
+
954
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
955
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
956
+
957
+ def prepare_for_training(self):
958
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
959
+
960
+ # glob cache files
961
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
962
+
963
+ # assign cache files to item info
964
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
965
+ for cache_file in latent_cache_files:
966
+ tokens = os.path.basename(cache_file).split("_")
967
+
968
+ image_size = tokens[-2] # 0000x0000
969
+ image_width, image_height = map(int, image_size.split("x"))
970
+ image_size = (image_width, image_height)
971
+
972
+ item_key = "_".join(tokens[:-2])
973
+ text_encoder_output_cache_file = os.path.join(
974
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
975
+ )
976
+ if not os.path.exists(text_encoder_output_cache_file):
977
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
978
+ continue
979
+
980
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
981
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
982
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
983
+
984
+ bucket = bucketed_item_info.get(bucket_reso, [])
985
+ bucket.append(item_info)
986
+ bucketed_item_info[bucket_reso] = bucket
987
+
988
+ # prepare batch manager
989
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
990
+ self.batch_manager.show_bucket_info()
991
+
992
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
993
+
994
+ def shuffle_buckets(self):
995
+ # set random seed for this epoch
996
+ random.seed(self.seed + self.current_epoch)
997
+ self.batch_manager.shuffle()
998
+
999
+ def __len__(self):
1000
+ if self.batch_manager is None:
1001
+ return 100 # dummy value
1002
+ return len(self.batch_manager)
1003
+
1004
+ def __getitem__(self, idx):
1005
+ return self.batch_manager[idx]
1006
+
1007
+
1008
+ class VideoDataset(BaseDataset):
1009
+ def __init__(
1010
+ self,
1011
+ resolution: Tuple[int, int],
1012
+ caption_extension: Optional[str],
1013
+ batch_size: int,
1014
+ enable_bucket: bool,
1015
+ bucket_no_upscale: bool,
1016
+ frame_extraction: Optional[str] = "head",
1017
+ frame_stride: Optional[int] = 1,
1018
+ frame_sample: Optional[int] = 1,
1019
+ target_frames: Optional[list[int]] = None,
1020
+ video_directory: Optional[str] = None,
1021
+ video_jsonl_file: Optional[str] = None,
1022
+ cache_directory: Optional[str] = None,
1023
+ debug_dataset: bool = False,
1024
+ ):
1025
+ super(VideoDataset, self).__init__(
1026
+ resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
1027
+ )
1028
+ self.video_directory = video_directory
1029
+ self.video_jsonl_file = video_jsonl_file
1030
+ self.target_frames = target_frames
1031
+ self.frame_extraction = frame_extraction
1032
+ self.frame_stride = frame_stride
1033
+ self.frame_sample = frame_sample
1034
+
1035
+ if video_directory is not None:
1036
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1037
+ elif video_jsonl_file is not None:
1038
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
1039
+
1040
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
1041
+ self.frame_extraction = "head"
1042
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
1043
+ if self.frame_extraction == "head":
1044
+ # head extraction. we can limit the number of frames to be extracted
1045
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
1046
+
1047
+ if self.cache_directory is None:
1048
+ self.cache_directory = self.video_directory
1049
+
1050
+ self.batch_manager = None
1051
+ self.num_train_items = 0
1052
+
1053
+ def get_metadata(self):
1054
+ metadata = super().get_metadata()
1055
+ if self.video_directory is not None:
1056
+ metadata["video_directory"] = os.path.basename(self.video_directory)
1057
+ if self.video_jsonl_file is not None:
1058
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1059
+ metadata["frame_extraction"] = self.frame_extraction
1060
+ metadata["frame_stride"] = self.frame_stride
1061
+ metadata["frame_sample"] = self.frame_sample
1062
+ metadata["target_frames"] = self.target_frames
1063
+ return metadata
1064
+
1065
+ def retrieve_latent_cache_batches(self, num_workers: int):
1066
+ bucket_selector = BucketSelector(self.resolution)
1067
+ self.datasource.set_bucket_selector(bucket_selector)
1068
+
1069
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1070
+
1071
+ # key: (width, height, frame_count), value: [ItemInfo]
1072
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1073
+ futures = []
1074
+
1075
+ def aggregate_future(consume_all: bool = False):
1076
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1077
+ completed_futures = [future for future in futures if future.done()]
1078
+ if len(completed_futures) == 0:
1079
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1080
+ time.sleep(0.1)
1081
+ continue
1082
+ else:
1083
+ break # submit batch if possible
1084
+
1085
+ for future in completed_futures:
1086
+ original_frame_size, video_key, video, caption = future.result()
1087
+
1088
+ frame_count = len(video)
1089
+ video = np.stack(video, axis=0)
1090
+ height, width = video.shape[1:3]
1091
+ bucket_reso = (width, height) # already resized
1092
+
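+ # frame_extraction decides which (start, length) windows are cut from the video: head takes the first frames, chunk splits into non-overlapping chunks, slide steps a window by frame_stride, uniform picks frame_sample evenly spaced windows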
1093
+ crop_pos_and_frames = []
1094
+ if self.frame_extraction == "head":
1095
+ for target_frame in self.target_frames:
1096
+ if frame_count >= target_frame:
1097
+ crop_pos_and_frames.append((0, target_frame))
1098
+ elif self.frame_extraction == "chunk":
1099
+ # split by target_frames
1100
+ for target_frame in self.target_frames:
1101
+ for i in range(0, frame_count, target_frame):
1102
+ if i + target_frame <= frame_count:
1103
+ crop_pos_and_frames.append((i, target_frame))
1104
+ elif self.frame_extraction == "slide":
1105
+ # slide window
1106
+ for target_frame in self.target_frames:
1107
+ if frame_count >= target_frame:
1108
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
1109
+ crop_pos_and_frames.append((i, target_frame))
1110
+ elif self.frame_extraction == "uniform":
1111
+ # select N frames uniformly
1112
+ for target_frame in self.target_frames:
1113
+ if frame_count >= target_frame:
1114
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1115
+ for i in frame_indices:
1116
+ crop_pos_and_frames.append((i, target_frame))
1117
+ else:
1118
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1119
+
1120
+ for crop_pos, target_frame in crop_pos_and_frames:
1121
+ cropped_video = video[crop_pos : crop_pos + target_frame]
1122
+ body, ext = os.path.splitext(video_key)
1123
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1124
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1125
+
1126
+ item_info = ItemInfo(
1127
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1128
+ )
1129
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1130
+
1131
+ batch = batches.get(batch_key, [])
1132
+ batch.append(item_info)
1133
+ batches[batch_key] = batch
1134
+
1135
+ futures.remove(future)
1136
+
1137
+ def submit_batch(flush: bool = False):
1138
+ for key in batches:
1139
+ if len(batches[key]) >= self.batch_size or flush:
1140
+ batch = batches[key][0 : self.batch_size]
1141
+ if len(batches[key]) > self.batch_size:
1142
+ batches[key] = batches[key][self.batch_size :]
1143
+ else:
1144
+ del batches[key]
1145
+ return key, batch
1146
+ return None, None
1147
+
1148
+ for operator in self.datasource:
1149
+
1150
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1151
+ video_key, video, caption = op()
1152
+ video: list[np.ndarray]
1153
+ frame_size = (video[0].shape[1], video[0].shape[0])
1154
+
1155
+ # resize if necessary
1156
+ bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
1157
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1158
+
1159
+ return frame_size, video_key, video, caption
1160
+
1161
+ future = executor.submit(fetch_and_resize, operator)
1162
+ futures.append(future)
1163
+ aggregate_future()
1164
+ while True:
1165
+ key, batch = submit_batch()
1166
+ if key is None:
1167
+ break
1168
+ yield key, batch
1169
+
1170
+ aggregate_future(consume_all=True)
1171
+ while True:
1172
+ key, batch = submit_batch(flush=True)
1173
+ if key is None:
1174
+ break
1175
+ yield key, batch
1176
+
1177
+ executor.shutdown()
1178
+
1179
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1180
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1181
+
1182
+ def prepare_for_training(self):
1183
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
1184
+
1185
+ # glob cache files
1186
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))
1187
+
1188
+ # assign cache files to item info
1189
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
1190
+ for cache_file in latent_cache_files:
1191
+ tokens = os.path.basename(cache_file).split("_")
1192
+
1193
+ image_size = tokens[-2] # 0000x0000
1194
+ image_width, image_height = map(int, image_size.split("x"))
1195
+ image_size = (image_width, image_height)
1196
+
1197
+ frame_pos, frame_count = tokens[-3].split("-")
1198
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
1199
+
1200
+ item_key = "_".join(tokens[:-3])
1201
+ text_encoder_output_cache_file = os.path.join(
1202
+ self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
1203
+ )
1204
+ if not os.path.exists(text_encoder_output_cache_file):
1205
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1206
+ continue
1207
+
1208
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1209
+ bucket_reso = (*bucket_reso, frame_count)
1210
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
1211
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1212
+
1213
+ bucket = bucketed_item_info.get(bucket_reso, [])
1214
+ bucket.append(item_info)
1215
+ bucketed_item_info[bucket_reso] = bucket
1216
+
1217
+ # prepare batch manager
1218
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1219
+ self.batch_manager.show_bucket_info()
1220
+
1221
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1222
+
1223
+ def shuffle_buckets(self):
1224
+ # set random seed for this epoch
1225
+ random.seed(self.seed + self.current_epoch)
1226
+ self.batch_manager.shuffle()
1227
+
1228
+ def __len__(self):
1229
+ if self.batch_manager is None:
1230
+ return 100 # dummy value
1231
+ return len(self.batch_manager)
1232
+
1233
+ def __getitem__(self, idx):
1234
+ return self.batch_manager[idx]
1235
+
1236
+
1237
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1238
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
1239
+ super().__init__(datasets)
1240
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
1241
+ self.num_train_items = 0
1242
+ for dataset in self.datasets:
1243
+ self.num_train_items += dataset.num_train_items
1244
+
1245
+ def set_current_epoch(self, epoch):
1246
+ for dataset in self.datasets:
1247
+ dataset.set_current_epoch(epoch)
1248
+
1249
+ def set_current_step(self, step):
1250
+ for dataset in self.datasets:
1251
+ dataset.set_current_step(step)
1252
+
1253
+ def set_max_train_steps(self, max_train_steps):
1254
+ for dataset in self.datasets:
1255
+ dataset.set_max_train_steps(max_train_steps)
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ Callable[[], nn.Module]: a factory that returns the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
hunyuan_model/attention.py ADDED
@@ -0,0 +1,230 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except ImportError:
13
+ flash_attn = None
14
+ flash_attn_varlen_func = None
15
+ _flash_attn_forward = None
16
+
17
+ try:
18
+ print(f"Trying to import sageattention")
19
+ from sageattention import sageattn_varlen
20
+
21
+ print("Successfully imported sageattention")
22
+ except ImportError:
23
+ print(f"Failed to import flash_attn and sageattention")
24
+ sageattn_varlen = None
25
+
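+ # pre/post layout transforms per backend: flash and sageattn flatten (batch, seq) into one varlen dimension, torch and vanilla move the head dim in front of the sequence dim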
26
+ MEMORY_LAYOUT = {
27
+ "flash": (
28
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
29
+ lambda x: x,
30
+ ),
31
+ "sageattn": (
32
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
33
+ lambda x: x,
34
+ ),
35
+ "torch": (
36
+ lambda x: x.transpose(1, 2),
37
+ lambda x: x.transpose(1, 2),
38
+ ),
39
+ "vanilla": (
40
+ lambda x: x.transpose(1, 2),
41
+ lambda x: x.transpose(1, 2),
42
+ ),
43
+ }
44
+
45
+
46
+ def get_cu_seqlens(text_mask, img_len):
47
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
48
+
49
+ Args:
50
+ text_mask (torch.Tensor): the mask of text
51
+ img_len (int): the length of image
52
+
53
+ Returns:
54
+ torch.Tensor: the calculated cu_seqlens for flash attention
55
+ """
56
+ batch_size = text_mask.shape[0]
57
+ text_len = text_mask.sum(dim=1)
58
+ max_len = text_mask.shape[1] + img_len
59
+
60
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
61
+
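+ # entries come in pairs per sample: the end of the valid image+text tokens, then the end of the padded sequence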
62
+ for i in range(batch_size):
63
+ s = text_len[i] + img_len
64
+ s1 = i * max_len + s
65
+ s2 = (i + 1) * max_len
66
+ cu_seqlens[2 * i + 1] = s1
67
+ cu_seqlens[2 * i + 2] = s2
68
+
69
+ return cu_seqlens
70
+
71
+
72
+ def attention(
73
+ q_or_qkv_list,
74
+ k=None,
75
+ v=None,
76
+ mode="flash",
77
+ drop_rate=0,
78
+ attn_mask=None,
79
+ causal=False,
80
+ cu_seqlens_q=None,
81
+ cu_seqlens_kv=None,
82
+ max_seqlen_q=None,
83
+ max_seqlen_kv=None,
84
+ batch_size=1,
85
+ ):
86
+ """
87
+ Perform QKV self attention.
88
+
89
+ Args:
90
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
91
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
92
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
93
+ mode (str): Attention mode. Choose from 'flash', 'sageattn', 'torch', and 'vanilla'.
94
+ drop_rate (float): Dropout rate in attention map. (default: 0)
95
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
96
+ (default: None)
97
+ causal (bool): Whether to use causal attention. (default: False)
98
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
99
+ used to index into q.
100
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
101
+ used to index into kv.
102
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
103
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
104
+
105
+ Returns:
106
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
107
+ """
108
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
109
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
110
+ q = pre_attn_layout(q)
111
+ k = pre_attn_layout(k)
112
+ v = pre_attn_layout(v)
113
+
114
+ if mode == "torch":
115
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
116
+ attn_mask = attn_mask.to(q.dtype)
117
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
118
+ if type(q_or_qkv_list) == list:
119
+ q_or_qkv_list.clear()
120
+ del q, k, v
121
+ del attn_mask
122
+ elif mode == "flash":
123
+ x = flash_attn_varlen_func(
124
+ q,
125
+ k,
126
+ v,
127
+ cu_seqlens_q,
128
+ cu_seqlens_kv,
129
+ max_seqlen_q,
130
+ max_seqlen_kv,
131
+ )
132
+ if type(q_or_qkv_list) == list:
133
+ q_or_qkv_list.clear()
134
+ del q, k, v
135
+ # x with shape [(bxs), a, d]
136
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
137
+ elif mode == "sageattn":
138
+ x = sageattn_varlen(
139
+ q,
140
+ k,
141
+ v,
142
+ cu_seqlens_q,
143
+ cu_seqlens_kv,
144
+ max_seqlen_q,
145
+ max_seqlen_kv,
146
+ )
147
+ if type(q_or_qkv_list) == list:
148
+ q_or_qkv_list.clear()
149
+ del q, k, v
150
+ # x with shape [(bxs), a, d]
151
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
152
+ elif mode == "vanilla":
153
+ scale_factor = 1 / math.sqrt(q.size(-1))
154
+
155
+ b, a, s, _ = q.shape
156
+ s1 = k.size(2)
157
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
158
+ if causal:
159
+ # Only applied to self attention
160
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
161
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
162
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
163
+ attn_bias.to(q.dtype)
164
+
165
+ if attn_mask is not None:
166
+ if attn_mask.dtype == torch.bool:
167
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
168
+ else:
169
+ attn_bias += attn_mask
170
+
171
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
172
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
173
+ attn += attn_bias
174
+ attn = attn.softmax(dim=-1)
175
+ attn = torch.dropout(attn, p=drop_rate, train=True)
176
+ x = attn @ v
177
+ else:
178
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
179
+
180
+ x = post_attn_layout(x)
181
+ b, s, a, d = x.shape
182
+ out = x.reshape(b, s, -1)
183
+ return out
184
+
185
+
186
+ def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
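+ # image tokens go through the hybrid sequence-parallel attention; the trailing text tokens are handled by a plain flash attention forward and concatenated back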
187
+ attn1 = hybrid_seq_parallel_attn(
188
+ None,
189
+ q[:, :img_q_len, :, :],
190
+ k[:, :img_kv_len, :, :],
191
+ v[:, :img_kv_len, :, :],
192
+ dropout_p=0.0,
193
+ causal=False,
194
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
195
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
196
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
197
+ joint_strategy="rear",
198
+ )
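+ # flash_attn >= 2.7.0 takes window_size_left/window_size_right instead of a single window_size tuple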
199
+ if flash_attn.__version__ >= "2.7.0":
200
+ attn2, *_ = _flash_attn_forward(
201
+ q[:, cu_seqlens_q[1] :],
202
+ k[:, cu_seqlens_kv[1] :],
203
+ v[:, cu_seqlens_kv[1] :],
204
+ dropout_p=0.0,
205
+ softmax_scale=q.shape[-1] ** (-0.5),
206
+ causal=False,
207
+ window_size_left=-1,
208
+ window_size_right=-1,
209
+ softcap=0.0,
210
+ alibi_slopes=None,
211
+ return_softmax=False,
212
+ )
213
+ else:
214
+ attn2, *_ = _flash_attn_forward(
215
+ q[:, cu_seqlens_q[1] :],
216
+ k[:, cu_seqlens_kv[1] :],
217
+ v[:, cu_seqlens_kv[1] :],
218
+ dropout_p=0.0,
219
+ softmax_scale=q.shape[-1] ** (-0.5),
220
+ causal=False,
221
+ window_size=(-1, -1),
222
+ softcap=0.0,
223
+ alibi_slopes=None,
224
+ return_softmax=False,
225
+ )
226
+ attn = torch.cat([attn1, attn2], dim=1)
227
+ b, s, a, d = attn.shape
228
+ attn = attn.reshape(b, s, -1)
229
+
230
+ return attn
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ try:
28
+ # This diffusers is modified and packed in the mirror.
29
+ from diffusers.loaders import FromOriginalVAEMixin
30
+ except ImportError:
31
+ # Use this to be compatible with the original diffusers.
32
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
127
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
128
+ self.tile_overlap_factor = 0.25
129
+
130
+ def _set_gradient_checkpointing(self, module, value=False):
131
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
132
+ module.gradient_checkpointing = value
133
+
134
+ def enable_temporal_tiling(self, use_tiling: bool = True):
135
+ self.use_temporal_tiling = use_tiling
136
+
137
+ def disable_temporal_tiling(self):
138
+ self.enable_temporal_tiling(False)
139
+
140
+ def enable_spatial_tiling(self, use_tiling: bool = True):
141
+ self.use_spatial_tiling = use_tiling
142
+
143
+ def disable_spatial_tiling(self):
144
+ self.enable_spatial_tiling(False)
145
+
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
150
+ processing larger videos.
151
+ """
152
+ self.enable_spatial_tiling(use_tiling)
153
+ self.enable_temporal_tiling(use_tiling)
154
+
155
+ def disable_tiling(self):
156
+ r"""
157
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
158
+ decoding in one step.
159
+ """
160
+ self.disable_spatial_tiling()
161
+ self.disable_temporal_tiling()
162
+
163
+ def enable_slicing(self):
164
+ r"""
165
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
166
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
167
+ """
168
+ self.use_slicing = True
169
+
170
+ def disable_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
173
+ decoding in one step.
174
+ """
175
+ self.use_slicing = False
176
+
177
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
178
+ # set chunk_size to CausalConv3d recursively
179
+ def set_chunk_size(module):
180
+ if hasattr(module, "chunk_size"):
181
+ module.chunk_size = chunk_size
182
+
183
+ self.apply(set_chunk_size)
184
+
185
+ @property
186
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
187
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
188
+ r"""
189
+ Returns:
190
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
191
+ indexed by its weight name.
192
+ """
193
+ # set recursively
194
+ processors = {}
195
+
196
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
197
+ if hasattr(module, "get_processor"):
198
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
199
+
200
+ for sub_name, child in module.named_children():
201
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
202
+
203
+ return processors
204
+
205
+ for name, module in self.named_children():
206
+ fn_recursive_add_processors(name, module, processors)
207
+
208
+ return processors
209
+
210
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
211
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
212
+ r"""
213
+ Sets the attention processor to use to compute attention.
214
+
215
+ Parameters:
216
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
217
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
218
+ for **all** `Attention` layers.
219
+
220
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
221
+ processor. This is strongly recommended when setting trainable attention processors.
222
+
223
+ """
224
+ count = len(self.attn_processors.keys())
225
+
226
+ if isinstance(processor, dict) and len(processor) != count:
227
+ raise ValueError(
228
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
229
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
230
+ )
231
+
232
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
233
+ if hasattr(module, "set_processor"):
234
+ if not isinstance(processor, dict):
235
+ module.set_processor(processor, _remove_lora=_remove_lora)
236
+ else:
237
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
238
+
239
+ for sub_name, child in module.named_children():
240
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
241
+
242
+ for name, module in self.named_children():
243
+ fn_recursive_attn_processor(name, module, processor)
244
+
245
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
246
+ def set_default_attn_processor(self):
247
+ """
248
+ Disables custom attention processors and sets the default attention implementation.
249
+ """
250
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnAddedKVProcessor()
252
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
253
+ processor = AttnProcessor()
254
+ else:
255
+ raise ValueError(
256
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
257
+ )
258
+
259
+ self.set_attn_processor(processor, _remove_lora=True)
260
+
261
+ @apply_forward_hook
262
+ def encode(
263
+ self, x: torch.FloatTensor, return_dict: bool = True
264
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
+ """
266
+ Encode a batch of images/videos into latents.
267
+
268
+ Args:
269
+ x (`torch.FloatTensor`): Input batch of images/videos.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
+
273
+ Returns:
274
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
275
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
+ """
277
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
278
+
279
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
280
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
281
+
282
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
283
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
284
+
285
+ if self.use_slicing and x.shape[0] > 1:
286
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
287
+ h = torch.cat(encoded_slices)
288
+ else:
289
+ h = self.encoder(x)
290
+
291
+ moments = self.quant_conv(h)
292
+ posterior = DiagonalGaussianDistribution(moments)
293
+
294
+ if not return_dict:
295
+ return (posterior,)
296
+
297
+ return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
301
+
302
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
303
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
304
+
305
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
306
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
307
+
308
+ z = self.post_quant_conv(z)
309
+ dec = self.decoder(z)
310
+
311
+ if not return_dict:
312
+ return (dec,)
313
+
314
+ return DecoderOutput(sample=dec)
315
+
316
+ @apply_forward_hook
317
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
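blend_v, blend_h and blend_t above apply the same linear cross-fade along the height, width and time axes respectively. A standalone 1-D sketch of that blend (not part of the model code):

    import torch

    def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Cross-fade the last `blend_extent` samples of `a` into the first
        # `blend_extent` samples of `b`, mirroring blend_h/blend_v/blend_t.
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        b = b.clone()
        for x in range(blend_extent):
            w = x / blend_extent  # 0 at the seam start, approaching 1 at the seam end
            b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
        return b

    a, b = torch.zeros(8), torch.ones(8)
    print(blend_1d(a, b, 4))  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])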
360
+
361
+ def spatial_tiled_encode(
362
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
363
+ ) -> AutoencoderKLOutput:
364
+ r"""Encode a batch of images/videos using a tiled encoder.
365
+
366
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
367
+ steps. This is useful to keep memory use constant regardless of image/video size. The end result of tiled encoding is
368
+ different from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
369
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
370
+ output, but they should be much less noticeable.
371
+
372
+ Args:
373
+ x (`torch.FloatTensor`): Input batch of images/videos.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
376
+
377
+ Returns:
378
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
379
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
380
+ `tuple` is returned.
381
+ """
382
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
383
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
384
+ row_limit = self.tile_latent_min_size - blend_extent
385
+
386
+ # Split video into tiles and encode them separately.
387
+ rows = []
388
+ for i in range(0, x.shape[-2], overlap_size):
389
+ row = []
390
+ for j in range(0, x.shape[-1], overlap_size):
391
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
392
+ tile = self.encoder(tile)
393
+ tile = self.quant_conv(tile)
394
+ row.append(tile)
395
+ rows.append(row)
396
+ result_rows = []
397
+ for i, row in enumerate(rows):
398
+ result_row = []
399
+ for j, tile in enumerate(row):
400
+ # blend the above tile and the left tile
401
+ # to the current tile and add the current tile to the result row
402
+ if i > 0:
403
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
404
+ if j > 0:
405
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
406
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
407
+ result_rows.append(torch.cat(result_row, dim=-1))
408
+
409
+ moments = torch.cat(result_rows, dim=-2)
410
+ if return_moments:
411
+ return moments
412
+
413
+ posterior = DiagonalGaussianDistribution(moments)
414
+ if not return_dict:
415
+ return (posterior,)
416
+
417
+ return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
420
+ r"""
421
+ Decode a batch of images/videos using a tiled decoder.
422
+
423
+ Args:
424
+ z (`torch.FloatTensor`): Input batch of latent vectors.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
427
+
428
+ Returns:
429
+ [`~models.vae.DecoderOutput`] or `tuple`:
430
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
431
+ returned.
432
+ """
433
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
434
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
435
+ row_limit = self.tile_sample_min_size - blend_extent
436
+
437
+ # Split z into overlapping tiles and decode them separately.
438
+ # The tiles have an overlap to avoid seams between tiles.
439
+ rows = []
440
+ for i in range(0, z.shape[-2], overlap_size):
441
+ row = []
442
+ for j in range(0, z.shape[-1], overlap_size):
443
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
444
+ tile = self.post_quant_conv(tile)
445
+ decoded = self.decoder(tile)
446
+ row.append(decoded)
447
+ rows.append(row)
448
+ result_rows = []
449
+ for i, row in enumerate(rows):
450
+ result_row = []
451
+ for j, tile in enumerate(row):
452
+ # blend the above tile and the left tile
453
+ # to the current tile and add the current tile to the result row
454
+ if i > 0:
455
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
456
+ if j > 0:
457
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
458
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
459
+ result_rows.append(torch.cat(result_row, dim=-1))
460
+
461
+ dec = torch.cat(result_rows, dim=-2)
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
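How the temporal tiling above walks the frame axis, as a worked sketch (the numbers are illustrative, assuming tile_sample_min_tsize=64, tile_overlap_factor=0.25 and 4x time compression):

    # Each tile grabs one extra leading frame for the causal encoder, and the
    # duplicated latent frame is dropped again for every tile but the first.
    tile_sample_min_tsize = 64
    tile_latent_min_tsize = 16
    tile_overlap_factor = 0.25

    overlap_size = int(tile_sample_min_tsize * (1 - tile_overlap_factor))  # 48
    blend_extent = int(tile_latent_min_tsize * tile_overlap_factor)        # 4
    t_limit = tile_latent_min_tsize - blend_extent                         # 12

    T = 129  # e.g. a 129-frame clip
    print(list(range(0, T, overlap_size)))  # [0, 48, 96] -> three temporal tiles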
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
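A hedged usage sketch of the class as defined above; `vae` is assumed to be an already-loaded AutoencoderKLCausal3D instance (how it is loaded is outside this snippet):

    import torch

    def vae_round_trip(vae: "AutoencoderKLCausal3D", video: torch.Tensor) -> torch.Tensor:
        """Encode and decode a (B, C, T, H, W) clip with the memory-saving options enabled."""
        vae.enable_tiling()   # spatial + temporal tiling for long or large clips
        vae.enable_slicing()  # decode one batch element at a time
        with torch.no_grad():
            posterior = vae.encode(video).latent_dist
            latents = posterior.sample()
            recon = vae.decode(latents).sample
        return recon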
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image/video to patch embedding using Conv3d (adapted from the original Conv2d implementation)
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
41
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
42
+ if bias:
43
+ nn.init.zeros_(self.proj.bias)
44
+
45
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
46
+
47
+ def forward(self, x):
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
51
+ x = self.norm(x)
52
+ return x
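Shape bookkeeping for the patch embedding above, as a sketch (patch_size=(1, 2, 2), in_chans=16 and embed_dim=3072 are assumed example values; the latent size is illustrative):

    import torch

    patch_embed = PatchEmbed(patch_size=(1, 2, 2), in_chans=16, embed_dim=3072)
    latents = torch.randn(1, 16, 33, 60, 104)   # (B, C, T, H, W)
    tokens = patch_embed(latents)
    # T stays 33, H and W are halved: 33 * 30 * 52 = 51480 tokens of width 3072.
    print(tokens.shape)  # torch.Size([1, 51480, 3072])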
53
+
54
+
55
+ class TextProjection(nn.Module):
56
+ """
57
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {"dtype": dtype, "device": device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
92
+ args = t[:, None].float() * freqs[None]
93
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
94
+ if dim % 2:
95
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
96
+ return embedding
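Quick check of the sinusoidal embedding defined above (straightforward use of the function as written):

    import torch

    t = torch.tensor([0.0, 500.0, 999.0])   # timesteps, one per batch element
    emb = timestep_embedding(t, dim=256)
    print(emb.shape)                         # torch.Size([3, 256])
    # First half is cos(t * freq), second half is sin(t * freq);
    # at t=0 that gives ones followed by zeros.
    print(emb[0, :128].allclose(torch.ones(128)), emb[0, 128:].abs().max().item())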
97
+
98
+
99
+ class TimestepEmbedder(nn.Module):
100
+ """
101
+ Embeds scalar timesteps into vector representations.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ act_layer,
108
+ frequency_embedding_size=256,
109
+ max_period=10000,
110
+ out_size=None,
111
+ dtype=None,
112
+ device=None,
113
+ ):
114
+ factory_kwargs = {"dtype": dtype, "device": device}
115
+ super().__init__()
116
+ self.frequency_embedding_size = frequency_embedding_size
117
+ self.max_period = max_period
118
+ if out_size is None:
119
+ out_size = hidden_size
120
+
121
+ self.mlp = nn.Sequential(
122
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
123
+ act_layer(),
124
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
125
+ )
126
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
127
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
128
+
129
+ def forward(self, t):
130
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
131
+ t_emb = self.mlp(t_freq)
132
+ return t_emb
hunyuan_model/helpers.py ADDED
@@ -0,0 +1,40 @@
1
+ import collections.abc
2
+
3
+ from itertools import repeat
4
+
5
+
6
+ def _ntuple(n):
7
+ def parse(x):
8
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9
+ x = tuple(x)
10
+ if len(x) == 1:
11
+ x = tuple(repeat(x[0], n))
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+
22
+
23
+ def as_tuple(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ if x is None or isinstance(x, (int, float, str)):
27
+ return (x,)
28
+ else:
29
+ raise ValueError(f"Unknown type {type(x)}")
30
+
31
+
32
+ def as_list_of_2tuple(x):
33
+ x = as_tuple(x)
34
+ if len(x) == 1:
35
+ x = (x[0], x[0])
36
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37
+ lst = []
38
+ for i in range(0, len(x), 2):
39
+ lst.append((x[i], x[i + 1]))
40
+ return lst
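Behaviour of the tuple helpers above, for reference (plain usage of the functions as defined):

    print(to_2tuple(2))                     # (2, 2)      scalars are repeated
    print(to_2tuple((1, 2, 2)))             # (1, 2, 2)   iterables pass through
    print(as_tuple(None))                   # (None,)
    print(as_list_of_2tuple(3))             # [(3, 3)]
    print(as_list_of_2tuple((1, 2, 3, 4)))  # [(1, 2), (3, 4)]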
hunyuan_model/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from .helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+
62
+ #
63
+ class MLPEmbedder(nn.Module):
64
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
+ factory_kwargs = {"device": device, "dtype": dtype}
67
+ super().__init__()
68
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
+ self.silu = nn.SiLU()
70
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.out_layer(self.silu(self.in_layer(x)))
74
+
75
+
76
+ class FinalLayer(nn.Module):
77
+ """The final layer of DiT."""
78
+
79
+ def __init__(
80
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
+ ):
82
+ factory_kwargs = {"device": device, "dtype": dtype}
83
+ super().__init__()
84
+
85
+ # Just use LayerNorm for the final layer
86
+ self.norm_final = nn.LayerNorm(
87
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
+ )
89
+ if isinstance(patch_size, int):
90
+ self.linear = nn.Linear(
91
+ hidden_size,
92
+ patch_size * patch_size * out_channels,
93
+ bias=True,
94
+ **factory_kwargs
95
+ )
96
+ else:
97
+ self.linear = nn.Linear(
98
+ hidden_size,
99
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
+ bias=True,
101
+ )
102
+ nn.init.zeros_(self.linear.weight)
103
+ nn.init.zeros_(self.linear.bias)
104
+
105
+ # Here we don't distinguish between the modulate types. Just use the simple one.
106
+ self.adaLN_modulation = nn.Sequential(
107
+ act_layer(),
108
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
+ )
110
+ # Zero-initialize the modulation
111
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
112
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
113
+
114
+ def forward(self, x, c):
115
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
+ x = self.linear(x)
118
+ return x
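FinalLayer above follows the usual adaLN recipe: the conditioning vector produces a shift and a scale, which modulate a parameter-free LayerNorm before the zero-initialized output projection. A minimal standalone sketch of that modulation step (assuming modulate applies x * (1 + scale) + shift broadcast over tokens, as in this project's modulate_layers):

    import torch
    import torch.nn as nn

    def modulate_sketch(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Assumed adaLN form: per-sample shift/scale broadcast over the token axis.
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    B, N, D = 2, 8, 16
    x = torch.randn(B, N, D)                  # normalized tokens
    c = torch.randn(B, D)                     # conditioning (timestep + pooled text)
    shift, scale = nn.Linear(D, 2 * D)(c).chunk(2, dim=1)
    print(modulate_sketch(x, shift, scale).shape)  # torch.Size([2, 8, 16])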
hunyuan_model/models.py ADDED
@@ -0,0 +1,997 @@
1
+ import os
2
+ from typing import Any, List, Tuple, Optional, Union, Dict
3
+ import accelerate
4
+ from einops import rearrange
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .activation_layers import get_activation_layer
11
+ from .norm_layers import get_norm_layer
12
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
13
+ from .attention import attention, parallel_attention, get_cu_seqlens
14
+ from .posemb_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
19
+ from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
20
+
21
+ from utils.safetensors_utils import MemoryEfficientSafeOpen
22
+
23
+
24
+ class MMDoubleStreamBlock(nn.Module):
25
+ """
26
+ A multimodal DiT block with separate modulation for
27
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
28
+ (Flux.1): https://github.com/black-forest-labs/flux
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ hidden_size: int,
34
+ heads_num: int,
35
+ mlp_width_ratio: float,
36
+ mlp_act_type: str = "gelu_tanh",
37
+ qk_norm: bool = True,
38
+ qk_norm_type: str = "rms",
39
+ qkv_bias: bool = False,
40
+ dtype: Optional[torch.dtype] = None,
41
+ device: Optional[torch.device] = None,
42
+ attn_mode: str = "flash",
43
+ ):
44
+ factory_kwargs = {"device": device, "dtype": dtype}
45
+ super().__init__()
46
+ self.attn_mode = attn_mode
47
+
48
+ self.deterministic = False
49
+ self.heads_num = heads_num
50
+ head_dim = hidden_size // heads_num
51
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
52
+
53
+ self.img_mod = ModulateDiT(
54
+ hidden_size,
55
+ factor=6,
56
+ act_layer=get_activation_layer("silu"),
57
+ **factory_kwargs,
58
+ )
59
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
60
+
61
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
62
+ qk_norm_layer = get_norm_layer(qk_norm_type)
63
+ self.img_attn_q_norm = (
64
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
65
+ )
66
+ self.img_attn_k_norm = (
67
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
68
+ )
69
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
70
+
71
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
72
+ self.img_mlp = MLP(
73
+ hidden_size,
74
+ mlp_hidden_dim,
75
+ act_layer=get_activation_layer(mlp_act_type),
76
+ bias=True,
77
+ **factory_kwargs,
78
+ )
79
+
80
+ self.txt_mod = ModulateDiT(
81
+ hidden_size,
82
+ factor=6,
83
+ act_layer=get_activation_layer("silu"),
84
+ **factory_kwargs,
85
+ )
86
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
87
+
88
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
89
+ self.txt_attn_q_norm = (
90
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
91
+ )
92
+ self.txt_attn_k_norm = (
93
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
94
+ )
95
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
96
+
97
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
98
+ self.txt_mlp = MLP(
99
+ hidden_size,
100
+ mlp_hidden_dim,
101
+ act_layer=get_activation_layer(mlp_act_type),
102
+ bias=True,
103
+ **factory_kwargs,
104
+ )
105
+ self.hybrid_seq_parallel_attn = None
106
+
107
+ self.gradient_checkpointing = False
108
+
109
+ def enable_deterministic(self):
110
+ self.deterministic = True
111
+
112
+ def disable_deterministic(self):
113
+ self.deterministic = False
114
+
115
+ def enable_gradient_checkpointing(self):
116
+ self.gradient_checkpointing = True
117
+
118
+ def _forward(
119
+ self,
120
+ img: torch.Tensor,
121
+ txt: torch.Tensor,
122
+ vec: torch.Tensor,
123
+ attn_mask: Optional[torch.Tensor] = None,
124
+ cu_seqlens_q: Optional[torch.Tensor] = None,
125
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
126
+ max_seqlen_q: Optional[int] = None,
127
+ max_seqlen_kv: Optional[int] = None,
128
+ freqs_cis: tuple = None,
129
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
130
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
131
+ 6, dim=-1
132
+ )
133
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
134
+ 6, dim=-1
135
+ )
136
+
137
+ # Prepare image for attention.
138
+ img_modulated = self.img_norm1(img)
139
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
140
+ img_qkv = self.img_attn_qkv(img_modulated)
141
+ img_modulated = None
142
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
143
+ img_qkv = None
144
+ # Apply QK-Norm if needed
145
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
146
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
147
+
148
+ # Apply RoPE if needed.
149
+ if freqs_cis is not None:
150
+ img_q_shape = img_q.shape
151
+ img_k_shape = img_k.shape
152
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
153
+ assert (
154
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
155
+ ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
156
+ # img_q, img_k = img_qq, img_kk
157
+
158
+ # Prepare txt for attention.
159
+ txt_modulated = self.txt_norm1(txt)
160
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
161
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
162
+ txt_modulated = None
163
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
164
+ txt_qkv = None
165
+ # Apply QK-Norm if needed.
166
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
167
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
168
+
169
+ # Run actual attention.
170
+ img_q_len = img_q.shape[1]
171
+ img_kv_len = img_k.shape[1]
172
+ batch_size = img_k.shape[0]
173
+ q = torch.cat((img_q, txt_q), dim=1)
174
+ img_q = txt_q = None
175
+ k = torch.cat((img_k, txt_k), dim=1)
176
+ img_k = txt_k = None
177
+ v = torch.cat((img_v, txt_v), dim=1)
178
+ img_v = txt_v = None
179
+
180
+ assert (
181
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
182
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
183
+
184
+ # attention computation start
185
+ if not self.hybrid_seq_parallel_attn:
186
+ l = [q, k, v]
187
+ q = k = v = None
188
+ attn = attention(
189
+ l,
190
+ mode=self.attn_mode,
191
+ attn_mask=attn_mask,
192
+ cu_seqlens_q=cu_seqlens_q,
193
+ cu_seqlens_kv=cu_seqlens_kv,
194
+ max_seqlen_q=max_seqlen_q,
195
+ max_seqlen_kv=max_seqlen_kv,
196
+ batch_size=batch_size,
197
+ )
198
+ else:
199
+ attn = parallel_attention(
200
+ self.hybrid_seq_parallel_attn,
201
+ q,
202
+ k,
203
+ v,
204
+ img_q_len=img_q_len,
205
+ img_kv_len=img_kv_len,
206
+ cu_seqlens_q=cu_seqlens_q,
207
+ cu_seqlens_kv=cu_seqlens_kv,
208
+ )
209
+
210
+ # attention computation end
211
+
212
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
213
+ attn = None
214
+
215
+ # Calculate the img blocks.
216
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
217
+ img_attn = None
218
+ img = img + apply_gate(
219
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
220
+ gate=img_mod2_gate,
221
+ )
222
+
223
+ # Calculate the txt blocks.
224
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
225
+ txt_attn = None
226
+ txt = txt + apply_gate(
227
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
228
+ gate=txt_mod2_gate,
229
+ )
230
+
231
+ return img, txt
232
+
233
+ # def forward(
234
+ # self,
235
+ # img: torch.Tensor,
236
+ # txt: torch.Tensor,
237
+ # vec: torch.Tensor,
238
+ # attn_mask: Optional[torch.Tensor] = None,
239
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
240
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
241
+ # max_seqlen_q: Optional[int] = None,
242
+ # max_seqlen_kv: Optional[int] = None,
243
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
244
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
245
+ def forward(self, *args, **kwargs):
246
+ if self.training and self.gradient_checkpointing:
247
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
248
+ else:
249
+ return self._forward(*args, **kwargs)
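Both stream blocks split their fused QKV projection with the einops pattern "B L (K H D) -> K B L H D". A standalone sketch with toy sizes (not the model's real dimensions):

    import torch
    from einops import rearrange

    B, L, H, D = 2, 16, 4, 8                # batch, tokens, heads, head_dim (toy values)
    qkv = torch.randn(B, L, 3 * H * D)      # output of a fused nn.Linear(hidden, 3 * hidden)
    q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=H)
    print(q.shape, k.shape, v.shape)        # each torch.Size([2, 16, 4, 8])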
250
+
251
+
252
+ class MMSingleStreamBlock(nn.Module):
253
+ """
254
+ A DiT block with parallel linear layers as described in
255
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
256
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
257
+ (Flux.1): https://github.com/black-forest-labs/flux
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ hidden_size: int,
263
+ heads_num: int,
264
+ mlp_width_ratio: float = 4.0,
265
+ mlp_act_type: str = "gelu_tanh",
266
+ qk_norm: bool = True,
267
+ qk_norm_type: str = "rms",
268
+ qk_scale: float = None,
269
+ dtype: Optional[torch.dtype] = None,
270
+ device: Optional[torch.device] = None,
271
+ attn_mode: str = "flash",
272
+ ):
273
+ factory_kwargs = {"device": device, "dtype": dtype}
274
+ super().__init__()
275
+ self.attn_mode = attn_mode
276
+
277
+ self.deterministic = False
278
+ self.hidden_size = hidden_size
279
+ self.heads_num = heads_num
280
+ head_dim = hidden_size // heads_num
281
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
282
+ self.mlp_hidden_dim = mlp_hidden_dim
283
+ self.scale = qk_scale or head_dim**-0.5
284
+
285
+ # qkv and mlp_in
286
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
287
+ # proj and mlp_out
288
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
289
+
290
+ qk_norm_layer = get_norm_layer(qk_norm_type)
291
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
292
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
293
+
294
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
295
+
296
+ self.mlp_act = get_activation_layer(mlp_act_type)()
297
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
298
+ self.hybrid_seq_parallel_attn = None
299
+
300
+ self.gradient_checkpointing = False
301
+
302
+ def enable_deterministic(self):
303
+ self.deterministic = True
304
+
305
+ def disable_deterministic(self):
306
+ self.deterministic = False
307
+
308
+ def enable_gradient_checkpointing(self):
309
+ self.gradient_checkpointing = True
310
+
311
+ def _forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ vec: torch.Tensor,
315
+ txt_len: int,
316
+ attn_mask: Optional[torch.Tensor] = None,
317
+ cu_seqlens_q: Optional[torch.Tensor] = None,
318
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
319
+ max_seqlen_q: Optional[int] = None,
320
+ max_seqlen_kv: Optional[int] = None,
321
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
322
+ ) -> torch.Tensor:
323
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
324
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
325
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
326
+ x_mod = None
327
+ # mlp = mlp.to("cpu", non_blocking=True)
328
+ # clean_memory_on_device(x.device)
329
+
330
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
331
+ qkv = None
332
+
333
+ # Apply QK-Norm if needed.
334
+ q = self.q_norm(q).to(v)
335
+ k = self.k_norm(k).to(v)
336
+
337
+ # Apply RoPE if needed.
338
+ if freqs_cis is not None:
339
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
340
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
341
+ q = k = None
342
+ img_q_shape = img_q.shape
343
+ img_k_shape = img_k.shape
344
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
345
+ assert (
346
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
347
+ ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
348
+ # img_q, img_k = img_qq, img_kk
349
+ # del img_qq, img_kk
350
+ q = torch.cat((img_q, txt_q), dim=1)
351
+ k = torch.cat((img_k, txt_k), dim=1)
352
+ del img_q, txt_q, img_k, txt_k
353
+
354
+ # Compute attention.
355
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
356
+
357
+ # attention computation start
358
+ if not self.hybrid_seq_parallel_attn:
359
+ l = [q, k, v]
360
+ q = k = v = None
361
+ attn = attention(
362
+ l,
363
+ mode=self.attn_mode,
364
+ attn_mask=attn_mask,
365
+ cu_seqlens_q=cu_seqlens_q,
366
+ cu_seqlens_kv=cu_seqlens_kv,
367
+ max_seqlen_q=max_seqlen_q,
368
+ max_seqlen_kv=max_seqlen_kv,
369
+ batch_size=x.shape[0],
370
+ )
371
+ else:
372
+ attn = parallel_attention(
373
+ self.hybrid_seq_parallel_attn,
374
+ q,
375
+ k,
376
+ v,
377
+ img_q_len=img_q.shape[1],
378
+ img_kv_len=img_k.shape[1],
379
+ cu_seqlens_q=cu_seqlens_q,
380
+ cu_seqlens_kv=cu_seqlens_kv,
381
+ )
382
+ # attention computation end
383
+
384
+ # Compute activation in mlp stream, cat again and run second linear layer.
385
+ # mlp = mlp.to(x.device)
386
+ mlp = self.mlp_act(mlp)
387
+ attn_mlp = torch.cat((attn, mlp), 2)
388
+ attn = None
389
+ mlp = None
390
+ output = self.linear2(attn_mlp)
391
+ attn_mlp = None
392
+ return x + apply_gate(output, gate=mod_gate)
393
+
394
+ # def forward(
395
+ # self,
396
+ # x: torch.Tensor,
397
+ # vec: torch.Tensor,
398
+ # txt_len: int,
399
+ # attn_mask: Optional[torch.Tensor] = None,
400
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
401
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
402
+ # max_seqlen_q: Optional[int] = None,
403
+ # max_seqlen_kv: Optional[int] = None,
404
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
405
+ # ) -> torch.Tensor:
406
+ def forward(self, *args, **kwargs):
407
+ if self.training and self.gradient_checkpointing:
408
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
409
+ else:
410
+ return self._forward(*args, **kwargs)
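The forward wrappers of both block classes share the same dispatch: route through torch.utils.checkpoint only when training with checkpointing enabled. A generic, self-contained sketch of that pattern:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlock(nn.Module):
        """Toy module using the same checkpointing dispatch as the stream blocks above."""

        def __init__(self, dim: int = 32):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
            self.gradient_checkpointing = False

        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.net(x)

        def forward(self, *args, **kwargs):
            if self.training and self.gradient_checkpointing:
                # Recompute activations during backward instead of storing them.
                return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
            return self._forward(*args, **kwargs)

    block = CheckpointedBlock()
    block.train()
    block.gradient_checkpointing = True
    x = torch.randn(4, 32, requires_grad=True)
    block(x).sum().backward()
    print(x.grad.shape)  # torch.Size([4, 32])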
411
+
412
+
413
+ class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
414
+ """
415
+ HunyuanVideo Transformer backbone
416
+
417
+ Originally inherited from ModelMixin and ConfigMixin for compatibility with diffusers' samplers; here the class subclasses nn.Module directly.
418
+
419
+ Reference:
420
+ [1] Flux.1: https://github.com/black-forest-labs/flux
421
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
422
+
423
+ Parameters
424
+ ----------
425
+ args: argparse.Namespace
426
+ The arguments parsed by argparse.
427
+ patch_size: list
428
+ The size of the patch.
429
+ in_channels: int
430
+ The number of input channels.
431
+ out_channels: int
432
+ The number of output channels.
433
+ hidden_size: int
434
+ The hidden size of the transformer backbone.
435
+ heads_num: int
436
+ The number of attention heads.
437
+ mlp_width_ratio: float
438
+ The ratio of the hidden size of the MLP in the transformer block.
439
+ mlp_act_type: str
440
+ The activation function of the MLP in the transformer block.
441
+ mm_double_blocks_depth: int
442
+ The number of transformer blocks in the double blocks.
443
+ mm_single_blocks_depth: int
444
+ The number of transformer blocks in the single blocks.
445
+ rope_dim_list: list
446
+ The dimension of the rotary embedding for t, h, w.
447
+ qkv_bias: bool
448
+ Whether to use bias in the qkv linear layer.
449
+ qk_norm: bool
450
+ Whether to use qk norm.
451
+ qk_norm_type: str
452
+ The type of qk norm.
453
+ guidance_embed: bool
454
+ Whether to use guidance embedding for distillation.
455
+ text_projection: str
456
+ The type of the text projection, default is single_refiner.
457
+ use_attention_mask: bool
458
+ Whether to use attention mask for text encoder.
459
+ dtype: torch.dtype
460
+ The dtype of the model.
461
+ device: torch.device
462
+ The device of the model.
463
+ attn_mode: str
464
+ The mode of the attention, default is flash.
465
+ """
466
+
467
+ # @register_to_config
468
+ def __init__(
469
+ self,
470
+ text_states_dim: int,
471
+ text_states_dim_2: int,
472
+ patch_size: list = [1, 2, 2],
473
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
474
+ out_channels: int = None,
475
+ hidden_size: int = 3072,
476
+ heads_num: int = 24,
477
+ mlp_width_ratio: float = 4.0,
478
+ mlp_act_type: str = "gelu_tanh",
479
+ mm_double_blocks_depth: int = 20,
480
+ mm_single_blocks_depth: int = 40,
481
+ rope_dim_list: List[int] = [16, 56, 56],
482
+ qkv_bias: bool = True,
483
+ qk_norm: bool = True,
484
+ qk_norm_type: str = "rms",
485
+ guidance_embed: bool = False, # For modulation.
486
+ text_projection: str = "single_refiner",
487
+ use_attention_mask: bool = True,
488
+ dtype: Optional[torch.dtype] = None,
489
+ device: Optional[torch.device] = None,
490
+ attn_mode: str = "flash",
491
+ ):
492
+ factory_kwargs = {"device": device, "dtype": dtype}
493
+ super().__init__()
494
+
495
+ self.patch_size = patch_size
496
+ self.in_channels = in_channels
497
+ self.out_channels = in_channels if out_channels is None else out_channels
498
+ self.unpatchify_channels = self.out_channels
499
+ self.guidance_embed = guidance_embed
500
+ self.rope_dim_list = rope_dim_list
501
+
502
+ # Text projection. Default to linear projection.
503
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
504
+ self.use_attention_mask = use_attention_mask
505
+ self.text_projection = text_projection
506
+
507
+ self.text_states_dim = text_states_dim
508
+ self.text_states_dim_2 = text_states_dim_2
509
+
510
+ if hidden_size % heads_num != 0:
511
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
512
+ pe_dim = hidden_size // heads_num
513
+ if sum(rope_dim_list) != pe_dim:
514
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
515
+ self.hidden_size = hidden_size
516
+ self.heads_num = heads_num
517
+
518
+ self.attn_mode = attn_mode
519
+
520
+ # image projection
521
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
522
+
523
+ # text projection
524
+ if self.text_projection == "linear":
525
+ self.txt_in = TextProjection(
526
+ self.text_states_dim,
527
+ self.hidden_size,
528
+ get_activation_layer("silu"),
529
+ **factory_kwargs,
530
+ )
531
+ elif self.text_projection == "single_refiner":
532
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
533
+ else:
534
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
535
+
536
+ # time modulation
537
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
538
+
539
+ # text modulation
540
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
541
+
542
+ # guidance modulation
543
+ self.guidance_in = (
544
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
545
+ )
546
+
547
+ # double blocks
548
+ self.double_blocks = nn.ModuleList(
549
+ [
550
+ MMDoubleStreamBlock(
551
+ self.hidden_size,
552
+ self.heads_num,
553
+ mlp_width_ratio=mlp_width_ratio,
554
+ mlp_act_type=mlp_act_type,
555
+ qk_norm=qk_norm,
556
+ qk_norm_type=qk_norm_type,
557
+ qkv_bias=qkv_bias,
558
+ attn_mode=attn_mode,
559
+ **factory_kwargs,
560
+ )
561
+ for _ in range(mm_double_blocks_depth)
562
+ ]
563
+ )
564
+
565
+ # single blocks
566
+ self.single_blocks = nn.ModuleList(
567
+ [
568
+ MMSingleStreamBlock(
569
+ self.hidden_size,
570
+ self.heads_num,
571
+ mlp_width_ratio=mlp_width_ratio,
572
+ mlp_act_type=mlp_act_type,
573
+ qk_norm=qk_norm,
574
+ qk_norm_type=qk_norm_type,
575
+ attn_mode=attn_mode,
576
+ **factory_kwargs,
577
+ )
578
+ for _ in range(mm_single_blocks_depth)
579
+ ]
580
+ )
581
+
582
+ self.final_layer = FinalLayer(
583
+ self.hidden_size,
584
+ self.patch_size,
585
+ self.out_channels,
586
+ get_activation_layer("silu"),
587
+ **factory_kwargs,
588
+ )
589
+
590
+ self.gradient_checkpointing = False
591
+ self.blocks_to_swap = None
592
+ self.offloader_double = None
593
+ self.offloader_single = None
594
+ self._enable_img_in_txt_in_offloading = False
595
+
596
+ @property
597
+ def device(self):
598
+ return next(self.parameters()).device
599
+
600
+ @property
601
+ def dtype(self):
602
+ return next(self.parameters()).dtype
603
+
604
+ def enable_gradient_checkpointing(self):
605
+ self.gradient_checkpointing = True
606
+
607
+ self.txt_in.enable_gradient_checkpointing()
608
+
609
+ for block in self.double_blocks + self.single_blocks:
610
+ block.enable_gradient_checkpointing()
611
+
612
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
613
+
614
+ def enable_img_in_txt_in_offloading(self):
615
+ self._enable_img_in_txt_in_offloading = True
616
+
617
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
618
+ self.blocks_to_swap = num_blocks
619
+ self.num_double_blocks = len(self.double_blocks)
620
+ self.num_single_blocks = len(self.single_blocks)
621
+ double_blocks_to_swap = num_blocks // 2
622
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
623
+
624
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
625
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
626
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
627
+ )
628
+
629
+ self.offloader_double = ModelOffloader(
630
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
631
+ )
632
+ self.offloader_single = ModelOffloader(
633
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
634
+ )
635
+ print(
636
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
637
+ )
638
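Worked example of the swap-count arithmetic in enable_block_swap above (num_blocks=20 is an illustrative choice; the model defaults to 20 double and 40 single blocks):

    num_double_blocks, num_single_blocks = 20, 40
    num_blocks = 20

    double_blocks_to_swap = num_blocks // 2                               # 10
    single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1  # 21

    # Both counts stay within the asserted limits (19 double, 39 single).
    print(double_blocks_to_swap, single_blocks_to_swap)  # 10 21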
+
639
+ def move_to_device_except_swap_blocks(self, device: torch.device):
640
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
641
+ if self.blocks_to_swap:
642
+ save_double_blocks = self.double_blocks
643
+ save_single_blocks = self.single_blocks
644
+ self.double_blocks = None
645
+ self.single_blocks = None
646
+
647
+ self.to(device)
648
+
649
+ if self.blocks_to_swap:
650
+ self.double_blocks = save_double_blocks
651
+ self.single_blocks = save_single_blocks
652
+
653
+ def prepare_block_swap_before_forward(self):
654
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
655
+ return
656
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
657
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
658
+
659
+ def enable_deterministic(self):
660
+ for block in self.double_blocks:
661
+ block.enable_deterministic()
662
+ for block in self.single_blocks:
663
+ block.enable_deterministic()
664
+
665
+ def disable_deterministic(self):
666
+ for block in self.double_blocks:
667
+ block.disable_deterministic()
668
+ for block in self.single_blocks:
669
+ block.disable_deterministic()
670
+
671
+ def forward(
672
+ self,
673
+ x: torch.Tensor,
674
+ t: torch.Tensor, # Should be in range(0, 1000).
675
+ text_states: torch.Tensor = None,
676
+ text_mask: torch.Tensor = None, # Now we don't use it.
677
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
678
+ freqs_cos: Optional[torch.Tensor] = None,
679
+ freqs_sin: Optional[torch.Tensor] = None,
680
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
681
+ return_dict: bool = True,
682
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
683
+ out = {}
684
+ img = x
685
+ txt = text_states
686
+ _, _, ot, oh, ow = x.shape
687
+ tt, th, tw = (
688
+ ot // self.patch_size[0],
689
+ oh // self.patch_size[1],
690
+ ow // self.patch_size[2],
691
+ )
692
+
693
+ # Prepare modulation vectors.
694
+ vec = self.time_in(t)
695
+
696
+ # text modulation
697
+ vec = vec + self.vector_in(text_states_2)
698
+
699
+ # guidance modulation
700
+ if self.guidance_embed:
701
+ if guidance is None:
702
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
703
+
704
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
705
+ vec = vec + self.guidance_in(guidance)
706
+
707
+ # Embed image and text.
708
+ if self._enable_img_in_txt_in_offloading:
709
+ self.img_in.to(x.device, non_blocking=True)
710
+ self.txt_in.to(x.device, non_blocking=True)
711
+ synchronize_device(x.device)
712
+
713
+ img = self.img_in(img)
714
+ if self.text_projection == "linear":
715
+ txt = self.txt_in(txt)
716
+ elif self.text_projection == "single_refiner":
717
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
718
+ else:
719
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
720
+
721
+ if self._enable_img_in_txt_in_offloading:
722
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
723
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
724
+ synchronize_device(x.device)
725
+ clean_memory_on_device(x.device)
726
+
727
+ txt_seq_len = txt.shape[1]
728
+ img_seq_len = img.shape[1]
729
+
730
+ # Compute cu_seqlens and max_seqlen for flash attention
731
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
732
+ cu_seqlens_kv = cu_seqlens_q
733
+ max_seqlen_q = img_seq_len + txt_seq_len
734
+ max_seqlen_kv = max_seqlen_q
735
+
736
+ attn_mask = None
737
+ if self.attn_mode == "torch":
738
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
739
+ bs = img.shape[0]
740
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
741
+
742
+ # calculate text length and total length
743
+ text_len = text_mask.sum(dim=1) # (bs, )
744
+ total_len = img_seq_len + text_len # (bs, )
745
+
746
+ # set attention mask
747
+ for i in range(bs):
748
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
749
+
750
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
751
+ # --------------------- Pass through DiT blocks ------------------------
752
+ for block_idx, block in enumerate(self.double_blocks):
753
+ double_block_args = [
754
+ img,
755
+ txt,
756
+ vec,
757
+ attn_mask,
758
+ cu_seqlens_q,
759
+ cu_seqlens_kv,
760
+ max_seqlen_q,
761
+ max_seqlen_kv,
762
+ freqs_cis,
763
+ ]
764
+
765
+ if self.blocks_to_swap:
766
+ self.offloader_double.wait_for_block(block_idx)
767
+
768
+ img, txt = block(*double_block_args)
769
+
770
+ if self.blocks_to_swap:
771
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
772
+
773
+ # Merge txt and img to pass through single stream blocks.
774
+ x = torch.cat((img, txt), 1)
775
+ if self.blocks_to_swap:
776
+ # delete img, txt to reduce memory usage
777
+ del img, txt
778
+ clean_memory_on_device(x.device)
779
+
780
+ if len(self.single_blocks) > 0:
781
+ for block_idx, block in enumerate(self.single_blocks):
782
+ single_block_args = [
783
+ x,
784
+ vec,
785
+ txt_seq_len,
786
+ attn_mask,
787
+ cu_seqlens_q,
788
+ cu_seqlens_kv,
789
+ max_seqlen_q,
790
+ max_seqlen_kv,
791
+ (freqs_cos, freqs_sin),
792
+ ]
793
+ if self.blocks_to_swap:
794
+ self.offloader_single.wait_for_block(block_idx)
795
+
796
+ x = block(*single_block_args)
797
+
798
+ if self.blocks_to_swap:
799
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
800
+
801
+ img = x[:, :img_seq_len, ...]
802
+ x = None
803
+
804
+ # ---------------------------- Final layer ------------------------------
805
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
806
+
807
+ img = self.unpatchify(img, tt, th, tw)
808
+ if return_dict:
809
+ out["x"] = img
810
+ return out
811
+ return img
812
+
813
+ def unpatchify(self, x, t, h, w):
814
+ """
815
+ x: (N, T, patch_size[0] * patch_size[1] * patch_size[2] * C)
+ imgs: (N, C, T * pt, H * ph, W * pw)
817
+ """
818
+ c = self.unpatchify_channels
819
+ pt, ph, pw = self.patch_size
820
+ assert t * h * w == x.shape[1]
821
+
822
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
823
+ x = torch.einsum("nthwcopq->nctohpwq", x)
824
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
825
+
826
+ return imgs
827
+
828
+ def params_count(self):
829
+ counts = {
830
+ "double": sum(
831
+ [
832
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
833
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
834
+ + sum(p.numel() for p in block.img_mlp.parameters())
835
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
836
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
837
+ + sum(p.numel() for p in block.txt_mlp.parameters())
838
+ for block in self.double_blocks
839
+ ]
840
+ ),
841
+ "single": sum(
842
+ [
843
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
844
+ for block in self.single_blocks
845
+ ]
846
+ ),
847
+ "total": sum(p.numel() for p in self.parameters()),
848
+ }
849
+ counts["attn+mlp"] = counts["double"] + counts["single"]
850
+ return counts
851
+
852
+
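A quick usage note for `params_count` (an editor's sketch, not part of the commit; it assumes a `transformer` already built with `load_transformer`, defined further below):

counts = transformer.params_count()
print(f"double: {counts['double'] / 1e9:.2f}B, single: {counts['single'] / 1e9:.2f}B, total: {counts['total'] / 1e9:.2f}B")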
853
+ #################################################################################
854
+ # HunyuanVideo Configs #
855
+ #################################################################################
856
+
857
+ HUNYUAN_VIDEO_CONFIG = {
858
+ "HYVideo-T/2": {
859
+ "mm_double_blocks_depth": 20,
860
+ "mm_single_blocks_depth": 40,
861
+ "rope_dim_list": [16, 56, 56],
862
+ "hidden_size": 3072,
863
+ "heads_num": 24,
864
+ "mlp_width_ratio": 4,
865
+ },
866
+ "HYVideo-T/2-cfgdistill": {
867
+ "mm_double_blocks_depth": 20,
868
+ "mm_single_blocks_depth": 40,
869
+ "rope_dim_list": [16, 56, 56],
870
+ "hidden_size": 3072,
871
+ "heads_num": 24,
872
+ "mlp_width_ratio": 4,
873
+ "guidance_embed": True,
874
+ },
875
+ }
876
+
877
+
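A small sanity check on the config above (editor's sketch): the per-head dimension implied by `hidden_size` and `heads_num` must equal `sum(rope_dim_list)`, which `get_rotary_pos_embed_by_shape` below asserts.

cfg = HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"]
head_dim = cfg["hidden_size"] // cfg["heads_num"]   # 3072 // 24 = 128
assert head_dim == sum(cfg["rope_dim_list"])        # 16 + 56 + 56 = 128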
878
+ def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
879
+ """load hunyuan video model
880
+
881
+ NOTE: Only HYVideo-T/2-cfgdistill is supported for now.
882
+
883
+ Args:
884
+ text_states_dim (int): text state dimension
885
+ text_states_dim_2 (int): text state dimension 2
886
+ in_channels (int): input channels number
887
+ out_channels (int): output channels number
888
+ factor_kwargs (dict): factor kwargs
889
+
890
+ Returns:
891
+ model (nn.Module): The hunyuan video model
892
+ """
893
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
894
+ model = HYVideoDiffusionTransformer(
895
+ text_states_dim=text_states_dim,
896
+ text_states_dim_2=text_states_dim_2,
897
+ in_channels=in_channels,
898
+ out_channels=out_channels,
899
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
900
+ **factor_kwargs,
901
+ )
902
+ return model
903
+ # else:
904
+ # raise NotImplementedError()
905
+
906
+
907
+ def load_state_dict(model, model_path):
908
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
909
+
910
+ load_key = "module"
911
+ if load_key in state_dict:
912
+ state_dict = state_dict[load_key]
913
+ else:
914
+ raise KeyError(
915
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
916
+ f"are: {list(state_dict.keys())}."
917
+ )
918
+ model.load_state_dict(state_dict, strict=True, assign=True)
919
+ return model
920
+
921
+
922
+ def load_transformer(dit_path, attn_mode, device, dtype) -> HYVideoDiffusionTransformer:
923
+ # =========================== Build main model ===========================
924
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode}
925
+ latent_channels = 16
926
+ in_channels = latent_channels
927
+ out_channels = latent_channels
928
+
929
+ with accelerate.init_empty_weights():
930
+ transformer = load_dit_model(
931
+ text_states_dim=4096,
932
+ text_states_dim_2=768,
933
+ in_channels=in_channels,
934
+ out_channels=out_channels,
935
+ factor_kwargs=factor_kwargs,
936
+ )
937
+
938
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
939
+ # loading safetensors: may be already fp8
940
+ with MemoryEfficientSafeOpen(dit_path) as f:
941
+ state_dict = {}
942
+ for k in f.keys():
943
+ tensor = f.get_tensor(k)
944
+ tensor = tensor.to(device=device, dtype=dtype)
945
+ # TODO support comfy model
946
+ # if k.startswith("model.model."):
947
+ # k = convert_comfy_model_key(k)
948
+ state_dict[k] = tensor
949
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
950
+ else:
951
+ transformer = load_state_dict(transformer, dit_path)
952
+
953
+ return transformer
954
+
955
+
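A minimal loading sketch for `load_transformer` (editor's illustration; the checkpoint path is a placeholder, and bf16 on CUDA is only an assumed configuration):

import torch

transformer = load_transformer(
    "ckpts/transformers/mp_rank_00_model_states.pt",  # placeholder path to the DiT weights
    attn_mode="torch",
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
transformer.eval()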
956
+ def get_rotary_pos_embed_by_shape(model, latents_size):
957
+ target_ndim = 3
958
+ ndim = 5 - 2
959
+
960
+ if isinstance(model.patch_size, int):
961
+ assert all(s % model.patch_size == 0 for s in latents_size), (
962
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
963
+ f"but got {latents_size}."
964
+ )
965
+ rope_sizes = [s // model.patch_size for s in latents_size]
966
+ elif isinstance(model.patch_size, list):
967
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
968
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
969
+ f"but got {latents_size}."
970
+ )
971
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
972
+
973
+ if len(rope_sizes) != target_ndim:
974
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
975
+ head_dim = model.hidden_size // model.heads_num
976
+ rope_dim_list = model.rope_dim_list
977
+ if rope_dim_list is None:
978
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
979
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
980
+
981
+ rope_theta = 256
982
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
983
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
984
+ )
985
+ return freqs_cos, freqs_sin
986
+
987
+
988
+ def get_rotary_pos_embed(vae_name, model, video_length, height, width):
989
+ # 884
990
+ if "884" in vae_name:
991
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
992
+ elif "888" in vae_name:
993
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
994
+ else:
995
+ latents_size = [video_length, height // 8, width // 8]
996
+
997
+ return get_rotary_pos_embed_by_shape(model, latents_size)
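To make the size bookkeeping above concrete, here is an editor's sketch (not part of the commit) for an assumed 129-frame 544x960 request with an "884" VAE and a (1, 2, 2) patch size:

video_length, height, width = 129, 544, 960
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]   # [33, 68, 120]
patch_size = (1, 2, 2)
rope_sizes = [s // p for s, p in zip(latents_size, patch_size)]         # [33, 34, 60]
img_seq_len = rope_sizes[0] * rope_sizes[1] * rope_sizes[2]             # 67320 image tokens
print(latents_size, rope_sizes, img_seq_len)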
hunyuan_model/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Callable
+
+ import torch
+ import torch.nn as nn
+
+
+ class ModulateDiT(nn.Module):
+     """Modulation layer for DiT."""
+     def __init__(
+         self,
+         hidden_size: int,
+         factor: int,
+         act_layer: Callable,
+         dtype=None,
+         device=None,
+     ):
+         factory_kwargs = {"dtype": dtype, "device": device}
+         super().__init__()
+         self.act = act_layer()
+         self.linear = nn.Linear(
+             hidden_size, factor * hidden_size, bias=True, **factory_kwargs
+         )
+         # Zero-initialize the modulation
+         nn.init.zeros_(self.linear.weight)
+         nn.init.zeros_(self.linear.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.linear(self.act(x))
+
+
+ def modulate(x, shift=None, scale=None):
+     """Modulate the input by shift and scale.
+
+     Args:
+         x (torch.Tensor): input tensor.
+         shift (torch.Tensor, optional): shift tensor. Defaults to None.
+         scale (torch.Tensor, optional): scale tensor. Defaults to None.
+
+     Returns:
+         torch.Tensor: the modulated output tensor.
+     """
+     if scale is None and shift is None:
+         return x
+     elif shift is None:
+         return x * (1 + scale.unsqueeze(1))
+     elif scale is None:
+         return x + shift.unsqueeze(1)
+     else:
+         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ def apply_gate(x, gate=None, tanh=False):
+     """Apply an optional gating tensor to the input.
+
+     Args:
+         x (torch.Tensor): input tensor.
+         gate (torch.Tensor, optional): gate tensor. Defaults to None.
+         tanh (bool, optional): whether to apply tanh to the gate. Defaults to False.
+
+     Returns:
+         torch.Tensor: the gated output tensor.
+     """
+     if gate is None:
+         return x
+     if tanh:
+         return x * gate.unsqueeze(1).tanh()
+     else:
+         return x * gate.unsqueeze(1)
+
+
+ def ckpt_wrapper(module):
+     def ckpt_forward(*inputs):
+         outputs = module(*inputs)
+         return outputs
+
+     return ckpt_forward
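A minimal usage sketch for the helpers above (editor's illustration; shapes follow the (batch, seq_len, hidden) convention used by the DiT blocks, and the hidden size 3072 is just an example):

import torch

from hunyuan_model.modulate_layers import apply_gate, modulate

x = torch.randn(2, 16, 3072)          # (batch, seq_len, hidden)
shift = torch.zeros(2, 3072)          # per-sample vectors, broadcast over seq_len via unsqueeze(1)
scale = 0.1 * torch.ones(2, 3072)
gate = torch.ones(2, 3072)

y = modulate(x, shift=shift, scale=scale)   # x * (1 + scale) + shift
y = apply_gate(y, gate=gate)                # y * gate
print(y.shape)                              # torch.Size([2, 16, 3072])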
hunyuan_model/norm_layers.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+
+
+ class RMSNorm(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         elementwise_affine=True,
+         eps: float = 1e-6,
+         device=None,
+         dtype=None,
+     ):
+         """
+         Initialize the RMSNorm normalization layer.
+
+         Args:
+             dim (int): The dimension of the input tensor.
+             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+         Attributes:
+             eps (float): A small value added to the denominator for numerical stability.
+             weight (nn.Parameter): Learnable scaling parameter.
+
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         if elementwise_affine:
+             self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+     def _norm(self, x):
+         """
+         Apply the RMSNorm normalization to the input tensor.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The normalized tensor.
+
+         """
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         """
+         Forward pass through the RMSNorm layer.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The output tensor after applying RMSNorm.
+
+         """
+         output = self._norm(x.float()).type_as(x)
+         if hasattr(self, "weight"):
+             # output = output * self.weight
+             # support fp8
+             output = output * self.weight.to(output.dtype)
+         return output
+
+
+ def get_norm_layer(norm_layer):
+     """
+     Get the normalization layer.
+
+     Args:
+         norm_layer (str): The type of normalization layer.
+
+     Returns:
+         norm_layer (nn.Module): The normalization layer.
+     """
+     if norm_layer == "layer":
+         return nn.LayerNorm
+     elif norm_layer == "rms":
+         return RMSNorm
+     else:
+         raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
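And a short usage sketch for this module (editor's illustration; the bf16 input and hidden size 3072 are assumptions):

import torch

from hunyuan_model.norm_layers import get_norm_layer

norm = get_norm_layer("rms")(3072, eps=1e-6)          # get_norm_layer returns the RMSNorm class, then instantiate
x = torch.randn(2, 16, 3072, dtype=torch.bfloat16)
y = norm(x)                                           # normalized in fp32, cast back to the input dtype
print(y.dtype, y.shape)                               # torch.bfloat16 torch.Size([2, 16, 3072])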
hunyuan_model/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1100 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.utils import BaseOutput
45
+
46
+ from ...constants import PRECISION_TO_TYPE
47
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
+ from ...text_encoder import TextEncoder
49
+ from ...modules import HYVideoDiffusionTransformer
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """"""
54
+
55
+
56
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
57
+ """
58
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
59
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
60
+ """
61
+ std_text = noise_pred_text.std(
62
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
63
+ )
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = (
69
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
+ )
71
+ return noise_cfg
72
+
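An editor's sketch of what `rescale_noise_cfg` does (illustrative only; it uses the function defined above with random tensors in place of real predictions):

import torch

noise_pred_text = torch.randn(1, 16, 33, 68, 120)
noise_pred_uncond = torch.randn(1, 16, 33, 68, 120)
noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)     # plain CFG, std inflated
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())   # the rescaled std is pulled back toward the text prediction's std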
73
+
74
+ def retrieve_timesteps(
75
+ scheduler,
76
+ num_inference_steps: Optional[int] = None,
77
+ device: Optional[Union[str, torch.device]] = None,
78
+ timesteps: Optional[List[int]] = None,
79
+ sigmas: Optional[List[float]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None and sigmas is not None:
106
+ raise ValueError(
107
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
+ )
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(
111
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
112
+ )
113
+ if not accepts_timesteps:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" timestep schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ elif sigmas is not None:
122
+ accept_sigmas = "sigmas" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
138
+
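A short usage sketch for `retrieve_timesteps` (editor's illustration; `DDIMScheduler` is just an arbitrary diffusers scheduler used as a stand-in):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_steps, timesteps.shape)   # 30 timesteps in descending order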
139
+ @dataclass
140
+ class HunyuanVideoPipelineOutput(BaseOutput):
141
+ videos: Union[torch.Tensor, np.ndarray]
142
+
143
+
144
+ class HunyuanVideoPipeline(DiffusionPipeline):
145
+ r"""
146
+ Pipeline for text-to-video generation using HunyuanVideo.
147
+
148
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
+ text_encoder ([`TextEncoder`]):
155
+ Frozen text-encoder.
156
+ text_encoder_2 ([`TextEncoder`]):
157
+ Frozen text-encoder_2.
158
+ transformer ([`HYVideoDiffusionTransformer`]):
159
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
+ scheduler ([`SchedulerMixin`]):
161
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
162
+ """
163
+
164
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
+ _optional_components = ["text_encoder_2"]
166
+ _exclude_from_cpu_offload = ["transformer"]
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
+
169
+ def __init__(
170
+ self,
171
+ vae: AutoencoderKL,
172
+ text_encoder: TextEncoder,
173
+ transformer: HYVideoDiffusionTransformer,
174
+ scheduler: KarrasDiffusionSchedulers,
175
+ text_encoder_2: Optional[TextEncoder] = None,
176
+ progress_bar_config: Dict[str, Any] = None,
177
+ args=None,
178
+ ):
179
+ super().__init__()
180
+
181
+ # ==========================================================================================
182
+ if progress_bar_config is None:
183
+ progress_bar_config = {}
184
+ if not hasattr(self, "_progress_bar_config"):
185
+ self._progress_bar_config = {}
186
+ self._progress_bar_config.update(progress_bar_config)
187
+
188
+ self.args = args
189
+ # ==========================================================================================
190
+
191
+ if (
192
+ hasattr(scheduler.config, "steps_offset")
193
+ and scheduler.config.steps_offset != 1
194
+ ):
195
+ deprecation_message = (
196
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
199
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
+ " file"
202
+ )
203
+ deprecate(
204
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
205
+ )
206
+ new_config = dict(scheduler.config)
207
+ new_config["steps_offset"] = 1
208
+ scheduler._internal_dict = FrozenDict(new_config)
209
+
210
+ if (
211
+ hasattr(scheduler.config, "clip_sample")
212
+ and scheduler.config.clip_sample is True
213
+ ):
214
+ deprecation_message = (
215
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
216
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220
+ )
221
+ deprecate(
222
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
223
+ )
224
+ new_config = dict(scheduler.config)
225
+ new_config["clip_sample"] = False
226
+ scheduler._internal_dict = FrozenDict(new_config)
227
+
228
+ self.register_modules(
229
+ vae=vae,
230
+ text_encoder=text_encoder,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ text_encoder_2=text_encoder_2,
234
+ )
235
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+
238
+ def encode_prompt(
239
+ self,
240
+ prompt,
241
+ device,
242
+ num_videos_per_prompt,
243
+ do_classifier_free_guidance,
244
+ negative_prompt=None,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
248
+ negative_attention_mask: Optional[torch.Tensor] = None,
249
+ lora_scale: Optional[float] = None,
250
+ clip_skip: Optional[int] = None,
251
+ text_encoder: Optional[TextEncoder] = None,
252
+ data_type: Optional[str] = "image",
253
+ ):
254
+ r"""
255
+ Encodes the prompt into text encoder hidden states.
256
+
257
+ Args:
258
+ prompt (`str` or `List[str]`, *optional*):
259
+ prompt to be encoded
260
+ device: (`torch.device`):
261
+ torch device
262
+ num_videos_per_prompt (`int`):
263
+ number of videos that should be generated per prompt
264
+ do_classifier_free_guidance (`bool`):
265
+ whether to use classifier free guidance or not
266
+ negative_prompt (`str` or `List[str]`, *optional*):
267
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
268
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
269
+ less than `1`).
270
+ prompt_embeds (`torch.Tensor`, *optional*):
271
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
272
+ provided, text embeddings will be generated from `prompt` input argument.
273
+ attention_mask (`torch.Tensor`, *optional*):
274
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
275
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
276
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
277
+ argument.
278
+ negative_attention_mask (`torch.Tensor`, *optional*):
279
+ lora_scale (`float`, *optional*):
280
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
281
+ clip_skip (`int`, *optional*):
282
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
283
+ the output of the pre-final layer will be used for computing the prompt embeddings.
284
+ text_encoder (TextEncoder, *optional*):
285
+ data_type (`str`, *optional*):
286
+ """
287
+ if text_encoder is None:
288
+ text_encoder = self.text_encoder
289
+
290
+ # set lora scale so that monkey patched LoRA
291
+ # function of text encoder can correctly access it
292
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
293
+ self._lora_scale = lora_scale
294
+
295
+ # dynamically adjust the LoRA scale
296
+ if not USE_PEFT_BACKEND:
297
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
298
+ else:
299
+ scale_lora_layers(text_encoder.model, lora_scale)
300
+
301
+ if prompt is not None and isinstance(prompt, str):
302
+ batch_size = 1
303
+ elif prompt is not None and isinstance(prompt, list):
304
+ batch_size = len(prompt)
305
+ else:
306
+ batch_size = prompt_embeds.shape[0]
307
+
308
+ if prompt_embeds is None:
309
+ # textual inversion: process multi-vector tokens if necessary
310
+ if isinstance(self, TextualInversionLoaderMixin):
311
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
312
+
313
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
314
+
315
+ if clip_skip is None:
316
+ prompt_outputs = text_encoder.encode(
317
+ text_inputs, data_type=data_type, device=device
318
+ )
319
+ prompt_embeds = prompt_outputs.hidden_state
320
+ else:
321
+ prompt_outputs = text_encoder.encode(
322
+ text_inputs,
323
+ output_hidden_states=True,
324
+ data_type=data_type,
325
+ device=device,
326
+ )
327
+ # Access the `hidden_states` first, that contains a tuple of
328
+ # all the hidden states from the encoder layers. Then index into
329
+ # the tuple to access the hidden states from the desired layer.
330
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
331
+ # We also need to apply the final LayerNorm here to not mess with the
332
+ # representations. The `last_hidden_states` that we typically use for
333
+ # obtaining the final prompt representations passes through the LayerNorm
334
+ # layer.
335
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
336
+ prompt_embeds
337
+ )
338
+
339
+ attention_mask = prompt_outputs.attention_mask
340
+ if attention_mask is not None:
341
+ attention_mask = attention_mask.to(device)
342
+ bs_embed, seq_len = attention_mask.shape
343
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
344
+ attention_mask = attention_mask.view(
345
+ bs_embed * num_videos_per_prompt, seq_len
346
+ )
347
+
348
+ if text_encoder is not None:
349
+ prompt_embeds_dtype = text_encoder.dtype
350
+ elif self.transformer is not None:
351
+ prompt_embeds_dtype = self.transformer.dtype
352
+ else:
353
+ prompt_embeds_dtype = prompt_embeds.dtype
354
+
355
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
356
+
357
+ if prompt_embeds.ndim == 2:
358
+ bs_embed, _ = prompt_embeds.shape
359
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
360
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
361
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
362
+ else:
363
+ bs_embed, seq_len, _ = prompt_embeds.shape
364
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
365
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
366
+ prompt_embeds = prompt_embeds.view(
367
+ bs_embed * num_videos_per_prompt, seq_len, -1
368
+ )
369
+
370
+ # get unconditional embeddings for classifier free guidance
371
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
372
+ uncond_tokens: List[str]
373
+ if negative_prompt is None:
374
+ uncond_tokens = [""] * batch_size
375
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
376
+ raise TypeError(
377
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
378
+ f" {type(prompt)}."
379
+ )
380
+ elif isinstance(negative_prompt, str):
381
+ uncond_tokens = [negative_prompt]
382
+ elif batch_size != len(negative_prompt):
383
+ raise ValueError(
384
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
385
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
386
+ " the batch size of `prompt`."
387
+ )
388
+ else:
389
+ uncond_tokens = negative_prompt
390
+
391
+ # textual inversion: process multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(
394
+ uncond_tokens, text_encoder.tokenizer
395
+ )
396
+
397
+ # max_length = prompt_embeds.shape[1]
398
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
399
+
400
+ negative_prompt_outputs = text_encoder.encode(
401
+ uncond_input, data_type=data_type, device=device
402
+ )
403
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
404
+
405
+ negative_attention_mask = negative_prompt_outputs.attention_mask
406
+ if negative_attention_mask is not None:
407
+ negative_attention_mask = negative_attention_mask.to(device)
408
+ _, seq_len = negative_attention_mask.shape
409
+ negative_attention_mask = negative_attention_mask.repeat(
410
+ 1, num_videos_per_prompt
411
+ )
412
+ negative_attention_mask = negative_attention_mask.view(
413
+ batch_size * num_videos_per_prompt, seq_len
414
+ )
415
+
416
+ if do_classifier_free_guidance:
417
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418
+ seq_len = negative_prompt_embeds.shape[1]
419
+
420
+ negative_prompt_embeds = negative_prompt_embeds.to(
421
+ dtype=prompt_embeds_dtype, device=device
422
+ )
423
+
424
+ if negative_prompt_embeds.ndim == 2:
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
426
+ 1, num_videos_per_prompt
427
+ )
428
+ negative_prompt_embeds = negative_prompt_embeds.view(
429
+ batch_size * num_videos_per_prompt, -1
430
+ )
431
+ else:
432
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
433
+ 1, num_videos_per_prompt, 1
434
+ )
435
+ negative_prompt_embeds = negative_prompt_embeds.view(
436
+ batch_size * num_videos_per_prompt, seq_len, -1
437
+ )
438
+
439
+ if text_encoder is not None:
440
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
441
+ # Retrieve the original scale by scaling back the LoRA layers
442
+ unscale_lora_layers(text_encoder.model, lora_scale)
443
+
444
+ return (
445
+ prompt_embeds,
446
+ negative_prompt_embeds,
447
+ attention_mask,
448
+ negative_attention_mask,
449
+ )
450
+
451
+ def decode_latents(self, latents, enable_tiling=True):
452
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+ if enable_tiling:
457
+ self.vae.enable_tiling()
458
+ image = self.vae.decode(latents, return_dict=False)[0]
459
+ else:
460
+ image = self.vae.decode(latents, return_dict=False)[0]
461
+ image = (image / 2 + 0.5).clamp(0, 1)
462
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
+ if image.ndim == 4:
464
+ image = image.cpu().permute(0, 2, 3, 1).float()
465
+ else:
466
+ image = image.cpu().float()
467
+ return image
468
+
469
+ def prepare_extra_func_kwargs(self, func, kwargs):
470
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
+ # and should be between [0, 1]
474
+ extra_step_kwargs = {}
475
+
476
+ for k, v in kwargs.items():
477
+ accepts = k in set(inspect.signature(func).parameters.keys())
478
+ if accepts:
479
+ extra_step_kwargs[k] = v
480
+ return extra_step_kwargs
481
+
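The filtering logic in `prepare_extra_func_kwargs`, inlined as a standalone editor's sketch (the toy `step` function is hypothetical):

import inspect

def step(sample, generator=None):   # a toy stand-in for scheduler.step
    return sample

kwargs = {"generator": None, "eta": 0.0}
accepted = {k: v for k, v in kwargs.items() if k in inspect.signature(step).parameters}
print(accepted)   # {'generator': None} -- `eta` is dropped because `step` does not accept it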
482
+ def check_inputs(
483
+ self,
484
+ prompt,
485
+ height,
486
+ width,
487
+ video_length,
488
+ callback_steps,
489
+ negative_prompt=None,
490
+ prompt_embeds=None,
491
+ negative_prompt_embeds=None,
492
+ callback_on_step_end_tensor_inputs=None,
493
+ vae_ver="88-4c-sd",
494
+ ):
495
+ if height % 8 != 0 or width % 8 != 0:
496
+ raise ValueError(
497
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
498
+ )
499
+
500
+ if video_length is not None:
501
+ if "884" in vae_ver:
502
+ if video_length != 1 and (video_length - 1) % 4 != 0:
503
+ raise ValueError(
504
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
505
+ )
506
+ elif "888" in vae_ver:
507
+ if video_length != 1 and (video_length - 1) % 8 != 0:
508
+ raise ValueError(
509
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
510
+ )
511
+
512
+ if callback_steps is not None and (
513
+ not isinstance(callback_steps, int) or callback_steps <= 0
514
+ ):
515
+ raise ValueError(
516
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
+ f" {type(callback_steps)}."
518
+ )
519
+ if callback_on_step_end_tensor_inputs is not None and not all(
520
+ k in self._callback_tensor_inputs
521
+ for k in callback_on_step_end_tensor_inputs
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
525
+ )
526
+
527
+ if prompt is not None and prompt_embeds is not None:
528
+ raise ValueError(
529
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
530
+ " only forward one of the two."
531
+ )
532
+ elif prompt is None and prompt_embeds is None:
533
+ raise ValueError(
534
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
535
+ )
536
+ elif prompt is not None and (
537
+ not isinstance(prompt, str) and not isinstance(prompt, list)
538
+ ):
539
+ raise ValueError(
540
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
541
+ )
542
+
543
+ if negative_prompt is not None and negative_prompt_embeds is not None:
544
+ raise ValueError(
545
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
546
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
547
+ )
548
+
549
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
550
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
551
+ raise ValueError(
552
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
553
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
554
+ f" {negative_prompt_embeds.shape}."
555
+ )
556
+
557
+
558
+ def prepare_latents(
559
+ self,
560
+ batch_size,
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ video_length,
565
+ dtype,
566
+ device,
567
+ generator,
568
+ latents=None,
569
+ ):
570
+ shape = (
571
+ batch_size,
572
+ num_channels_latents,
573
+ video_length,
574
+ int(height) // self.vae_scale_factor,
575
+ int(width) // self.vae_scale_factor,
576
+ )
577
+ if isinstance(generator, list) and len(generator) != batch_size:
578
+ raise ValueError(
579
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
+ )
582
+
583
+ if latents is None:
584
+ latents = randn_tensor(
585
+ shape, generator=generator, device=device, dtype=dtype
586
+ )
587
+ else:
588
+ latents = latents.to(device)
589
+
590
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
+ if hasattr(self.scheduler, "init_noise_sigma"):
592
+ # scale the initial noise by the standard deviation required by the scheduler
593
+ latents = latents * self.scheduler.init_noise_sigma
594
+ return latents
595
+
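For orientation, an editor's sketch of the shape `prepare_latents` allocates for an assumed 544x960, 129-frame request (latent frame count as computed in step 4 of `__call__` below; `vae_scale_factor` is 8):

batch_size, num_channels_latents = 1, 16
height, width, vae_scale_factor = 544, 960, 8
latent_frames = (129 - 1) // 4 + 1   # 33 for the "884" VAE
shape = (batch_size, num_channels_latents, latent_frames, height // vae_scale_factor, width // vae_scale_factor)
print(shape)   # (1, 16, 33, 68, 120)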
596
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
+ def get_guidance_scale_embedding(
598
+ self,
599
+ w: torch.Tensor,
600
+ embedding_dim: int = 512,
601
+ dtype: torch.dtype = torch.float32,
602
+ ) -> torch.Tensor:
603
+ """
604
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
+
606
+ Args:
607
+ w (`torch.Tensor`):
608
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
+ embedding_dim (`int`, *optional*, defaults to 512):
610
+ Dimension of the embeddings to generate.
611
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
+ Data type of the generated embeddings.
613
+
614
+ Returns:
615
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
+ """
617
+ assert len(w.shape) == 1
618
+ w = w * 1000.0
619
+
620
+ half_dim = embedding_dim // 2
621
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
+ emb = w.to(dtype)[:, None] * emb[None, :]
624
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
+ if embedding_dim % 2 == 1: # zero pad
626
+ emb = torch.nn.functional.pad(emb, (0, 1))
627
+ assert emb.shape == (w.shape[0], embedding_dim)
628
+ return emb
629
+
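The sinusoidal embedding computed by `get_guidance_scale_embedding`, reproduced as a standalone editor's sketch for two guidance weights and `embedding_dim=512`:

import torch

w = torch.tensor([7.5, 9.0]) * 1000.0
half_dim = 512 // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = w[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
print(emb.shape)   # torch.Size([2, 512])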
630
+ @property
631
+ def guidance_scale(self):
632
+ return self._guidance_scale
633
+
634
+ @property
635
+ def guidance_rescale(self):
636
+ return self._guidance_rescale
637
+
638
+ @property
639
+ def clip_skip(self):
640
+ return self._clip_skip
641
+
642
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
643
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
+ # corresponds to doing no classifier free guidance.
645
+ @property
646
+ def do_classifier_free_guidance(self):
647
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
648
+ return self._guidance_scale > 1
649
+
650
+ @property
651
+ def cross_attention_kwargs(self):
652
+ return self._cross_attention_kwargs
653
+
654
+ @property
655
+ def num_timesteps(self):
656
+ return self._num_timesteps
657
+
658
+ @property
659
+ def interrupt(self):
660
+ return self._interrupt
661
+
662
+ @torch.no_grad()
663
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
664
+ def __call__(
665
+ self,
666
+ prompt: Union[str, List[str]],
667
+ height: int,
668
+ width: int,
669
+ video_length: int,
670
+ data_type: str = "video",
671
+ num_inference_steps: int = 50,
672
+ timesteps: List[int] = None,
673
+ sigmas: List[float] = None,
674
+ guidance_scale: float = 7.5,
675
+ negative_prompt: Optional[Union[str, List[str]]] = None,
676
+ num_videos_per_prompt: Optional[int] = 1,
677
+ eta: float = 0.0,
678
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
+ latents: Optional[torch.Tensor] = None,
680
+ prompt_embeds: Optional[torch.Tensor] = None,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
683
+ negative_attention_mask: Optional[torch.Tensor] = None,
684
+ output_type: Optional[str] = "pil",
685
+ return_dict: bool = True,
686
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
+ guidance_rescale: float = 0.0,
688
+ clip_skip: Optional[int] = None,
689
+ callback_on_step_end: Optional[
690
+ Union[
691
+ Callable[[int, int, Dict], None],
692
+ PipelineCallback,
693
+ MultiPipelineCallbacks,
694
+ ]
695
+ ] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
+ vae_ver: str = "88-4c-sd",
699
+ enable_tiling: bool = False,
700
+ n_tokens: Optional[int] = None,
701
+ embedded_guidance_scale: Optional[float] = None,
702
+ **kwargs,
703
+ ):
704
+ r"""
705
+ The call function to the pipeline for generation.
706
+
707
+ Args:
708
+ prompt (`str` or `List[str]`):
709
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
710
+ height (`int`):
711
+ The height in pixels of the generated image.
712
+ width (`int`):
713
+ The width in pixels of the generated image.
714
+ video_length (`int`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, *optional*, defaults to 50):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ timesteps (`List[int]`, *optional*):
720
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
+ passed will be used. Must be in descending order.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
+ will be used.
727
+ guidance_scale (`float`, *optional*, defaults to 7.5):
728
+ A higher guidance scale value encourages the model to generate images closely linked to the text
729
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
+ negative_prompt (`str` or `List[str]`, *optional*):
731
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of images to generate per prompt.
735
+ eta (`float`, *optional*, defaults to 0.0):
736
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
+ generation deterministic.
741
+ latents (`torch.Tensor`, *optional*):
742
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
+ tensor is generated by sampling using the supplied random `generator`.
745
+ prompt_embeds (`torch.Tensor`, *optional*):
746
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
+ provided, text embeddings are generated from the `prompt` input argument.
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
+
752
+ output_type (`str`, *optional*, defaults to `"pil"`):
753
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
+ return_dict (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
+ plain tuple.
757
+ cross_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
761
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
+ using zero terminal SNR.
764
+ clip_skip (`int`, *optional*):
765
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
+ the output of the pre-final layer will be used for computing the prompt embeddings.
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
783
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
784
+ "not-safe-for-work" (nsfw) content.
785
+ """
786
+ callback = kwargs.pop("callback", None)
787
+ callback_steps = kwargs.pop("callback_steps", None)
788
+
789
+ if callback is not None:
790
+ deprecate(
791
+ "callback",
792
+ "1.0.0",
793
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
+ )
795
+ if callback_steps is not None:
796
+ deprecate(
797
+ "callback_steps",
798
+ "1.0.0",
799
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
+ )
801
+
802
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
+
805
+ # 0. Default height and width to unet
806
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
+ # to deal with lora scaling and other possible forward hooks
809
+
810
+ # 1. Check inputs. Raise error if not correct
811
+ self.check_inputs(
812
+ prompt,
813
+ height,
814
+ width,
815
+ video_length,
816
+ callback_steps,
817
+ negative_prompt,
818
+ prompt_embeds,
819
+ negative_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs,
821
+ vae_ver=vae_ver,
822
+ )
823
+
824
+ self._guidance_scale = guidance_scale
825
+ self._guidance_rescale = guidance_rescale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+ self._interrupt = False
829
+
830
+ # 2. Define call parameters
831
+ if prompt is not None and isinstance(prompt, str):
832
+ batch_size = 1
833
+ elif prompt is not None and isinstance(prompt, list):
834
+ batch_size = len(prompt)
835
+ else:
836
+ batch_size = prompt_embeds.shape[0]
837
+
838
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
+
840
+ # 3. Encode input prompt
841
+ lora_scale = (
842
+ self.cross_attention_kwargs.get("scale", None)
843
+ if self.cross_attention_kwargs is not None
844
+ else None
845
+ )
846
+
847
+ (
848
+ prompt_embeds,
849
+ negative_prompt_embeds,
850
+ prompt_mask,
851
+ negative_prompt_mask,
852
+ ) = self.encode_prompt(
853
+ prompt,
854
+ device,
855
+ num_videos_per_prompt,
856
+ self.do_classifier_free_guidance,
857
+ negative_prompt,
858
+ prompt_embeds=prompt_embeds,
859
+ attention_mask=attention_mask,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ negative_attention_mask=negative_attention_mask,
862
+ lora_scale=lora_scale,
863
+ clip_skip=self.clip_skip,
864
+ data_type=data_type,
865
+ )
866
+ if self.text_encoder_2 is not None:
867
+ (
868
+ prompt_embeds_2,
869
+ negative_prompt_embeds_2,
870
+ prompt_mask_2,
871
+ negative_prompt_mask_2,
872
+ ) = self.encode_prompt(
873
+ prompt,
874
+ device,
875
+ num_videos_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ negative_prompt,
878
+ prompt_embeds=None,
879
+ attention_mask=None,
880
+ negative_prompt_embeds=None,
881
+ negative_attention_mask=None,
882
+ lora_scale=lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ text_encoder=self.text_encoder_2,
885
+ data_type=data_type,
886
+ )
887
+ else:
888
+ prompt_embeds_2 = None
889
+ negative_prompt_embeds_2 = None
890
+ prompt_mask_2 = None
891
+ negative_prompt_mask_2 = None
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ if self.do_classifier_free_guidance:
897
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
+ if prompt_mask is not None:
899
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
+ if prompt_embeds_2 is not None:
901
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
+ if prompt_mask_2 is not None:
903
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
+
905
+
906
+ # 4. Prepare timesteps
907
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
+ )
910
+ timesteps, num_inference_steps = retrieve_timesteps(
911
+ self.scheduler,
912
+ num_inference_steps,
913
+ device,
914
+ timesteps,
915
+ sigmas,
916
+ **extra_set_timesteps_kwargs,
917
+ )
918
+
919
+ if "884" in vae_ver:
920
+ video_length = (video_length - 1) // 4 + 1
921
+ elif "888" in vae_ver:
922
+ video_length = (video_length - 1) // 8 + 1
923
+ else:
924
+ video_length = video_length
925
+
926
+ # 5. Prepare latent variables
927
+ num_channels_latents = self.transformer.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_videos_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ video_length,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
942
+ self.scheduler.step,
943
+ {"generator": generator, "eta": eta},
944
+ )
945
+
946
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
+ autocast_enabled = (
948
+ target_dtype != torch.float32
949
+ ) and not self.args.disable_autocast
950
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
+ vae_autocast_enabled = (
952
+ vae_dtype != torch.float32
953
+ ) and not self.args.disable_autocast
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ self._num_timesteps = len(timesteps)
958
+
959
+ # if is_progress_bar:
960
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
961
+ for i, t in enumerate(timesteps):
962
+ if self.interrupt:
963
+ continue
964
+
965
+ # expand the latents if we are doing classifier free guidance
966
+ latent_model_input = (
967
+ torch.cat([latents] * 2)
968
+ if self.do_classifier_free_guidance
969
+ else latents
970
+ )
971
+ latent_model_input = self.scheduler.scale_model_input(
972
+ latent_model_input, t
973
+ )
974
+
975
+ t_expand = t.repeat(latent_model_input.shape[0])
976
+ guidance_expand = (
977
+ torch.tensor(
978
+ [embedded_guidance_scale] * latent_model_input.shape[0],
979
+ dtype=torch.float32,
980
+ device=device,
981
+ ).to(target_dtype)
982
+ * 1000.0
983
+ if embedded_guidance_scale is not None
984
+ else None
985
+ )
986
+
987
+ # predict the noise residual
988
+ with torch.autocast(
989
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
990
+ ):
991
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
992
+ latent_model_input, # [2, 16, 33, 24, 42]
993
+ t_expand, # [2]
994
+ text_states=prompt_embeds, # [2, 256, 4096]
995
+ text_mask=prompt_mask, # [2, 256]
996
+ text_states_2=prompt_embeds_2, # [2, 768]
997
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
998
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
999
+ guidance=guidance_expand,
1000
+ return_dict=True,
1001
+ )[
1002
+ "x"
1003
+ ]
1004
+
1005
+ # perform guidance
1006
+ if self.do_classifier_free_guidance:
1007
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1008
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1009
+ noise_pred_text - noise_pred_uncond
1010
+ )
1011
+
1012
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1013
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1014
+ noise_pred = rescale_noise_cfg(
1015
+ noise_pred,
1016
+ noise_pred_text,
1017
+ guidance_rescale=self.guidance_rescale,
1018
+ )
1019
+
1020
+ # compute the previous noisy sample x_t -> x_t-1
1021
+ latents = self.scheduler.step(
1022
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1023
+ )[0]
1024
+
1025
+ if callback_on_step_end is not None:
1026
+ callback_kwargs = {}
1027
+ for k in callback_on_step_end_tensor_inputs:
1028
+ callback_kwargs[k] = locals()[k]
1029
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1030
+
1031
+ latents = callback_outputs.pop("latents", latents)
1032
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1033
+ negative_prompt_embeds = callback_outputs.pop(
1034
+ "negative_prompt_embeds", negative_prompt_embeds
1035
+ )
1036
+
1037
+ # call the callback, if provided
1038
+ if i == len(timesteps) - 1 or (
1039
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1040
+ ):
1041
+ if progress_bar is not None:
1042
+ progress_bar.update()
1043
+ if callback is not None and i % callback_steps == 0:
1044
+ step_idx = i // getattr(self.scheduler, "order", 1)
1045
+ callback(step_idx, t, latents)
1046
+
1047
+ if not output_type == "latent":
1048
+ expand_temporal_dim = False
1049
+ if len(latents.shape) == 4:
1050
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1051
+ latents = latents.unsqueeze(2)
1052
+ expand_temporal_dim = True
1053
+ elif len(latents.shape) == 5:
1054
+ pass
1055
+ else:
1056
+ raise ValueError(
1057
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1058
+ )
1059
+
1060
+ if (
1061
+ hasattr(self.vae.config, "shift_factor")
1062
+ and self.vae.config.shift_factor
1063
+ ):
1064
+ latents = (
1065
+ latents / self.vae.config.scaling_factor
1066
+ + self.vae.config.shift_factor
1067
+ )
1068
+ else:
1069
+ latents = latents / self.vae.config.scaling_factor
1070
+
1071
+ with torch.autocast(
1072
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1073
+ ):
1074
+ if enable_tiling:
1075
+ self.vae.enable_tiling()
1076
+ image = self.vae.decode(
1077
+ latents, return_dict=False, generator=generator
1078
+ )[0]
1079
+ else:
1080
+ image = self.vae.decode(
1081
+ latents, return_dict=False, generator=generator
1082
+ )[0]
1083
+
1084
+ if expand_temporal_dim or image.shape[2] == 1:
1085
+ image = image.squeeze(2)
1086
+
1087
+ else:
1088
+ image = latents
1089
+
1090
+ image = (image / 2 + 0.5).clamp(0, 1)
1091
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1092
+ image = image.cpu().float()
1093
+
1094
+ # Offload all models
1095
+ self.maybe_free_model_hooks()
1096
+
1097
+ if not return_dict:
1098
+ return image
1099
+
1100
+ return HunyuanVideoPipelineOutput(videos=image)
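For reference, the classifier-free guidance recombination used in the denoising loop above can be exercised in isolation. This is a minimal sketch with dummy tensors and an assumed guidance scale of 7.0; it is not part of the uploaded pipeline file.

import torch

guidance_scale = 7.0  # assumed value, for illustration only
noise_pred = torch.randn(2, 16, 33, 24, 42)  # [uncond; cond] stacked along the batch dim, as in the loop above
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)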
hunyuan_model/posemb_layers.py ADDED
@@ -0,0 +1,310 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (np.ndarray): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def reshape_for_broadcast(
66
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
+ x: torch.Tensor,
68
+ head_first=False,
69
+ ):
70
+ """
71
+ Reshape frequency tensor for broadcasting it with another tensor.
72
+
73
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
75
+
76
+ Notes:
77
+ When using FlashMHAModified, head_first should be False.
78
+ When using Attention, head_first should be True.
79
+
80
+ Args:
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ head_first (bool): head dimension first (except batch dim) or not.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+
88
+ Raises:
89
+ AssertionError: If the frequency tensor doesn't match the expected shape.
90
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
+ """
92
+ ndim = x.ndim
93
+ assert 0 <= 1 < ndim
94
+
95
+ if isinstance(freqs_cis, tuple):
96
+ # freqs_cis: (cos, sin) in real space
97
+ if head_first:
98
+ assert freqs_cis[0].shape == (
99
+ x.shape[-2],
100
+ x.shape[-1],
101
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
+ shape = [
103
+ d if i == ndim - 2 or i == ndim - 1 else 1
104
+ for i, d in enumerate(x.shape)
105
+ ]
106
+ else:
107
+ assert freqs_cis[0].shape == (
108
+ x.shape[1],
109
+ x.shape[-1],
110
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
+ else:
114
+ # freqs_cis: values in complex space
115
+ if head_first:
116
+ assert freqs_cis.shape == (
117
+ x.shape[-2],
118
+ x.shape[-1],
119
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
+ shape = [
121
+ d if i == ndim - 2 or i == ndim - 1 else 1
122
+ for i, d in enumerate(x.shape)
123
+ ]
124
+ else:
125
+ assert freqs_cis.shape == (
126
+ x.shape[1],
127
+ x.shape[-1],
128
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def rotate_half(x):
134
+ x_real, x_imag = (
135
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
+ ) # [B, S, H, D//2]
137
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
+ head_first: bool = False,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Apply rotary embeddings to input tensors using the given frequency tensor.
148
+
149
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
+ returned as real tensors.
153
+
154
+ Args:
155
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
+ head_first (bool): head dimension first (except batch dim) or not.
159
+
160
+ Returns:
161
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
+
163
+ """
164
+ xk_out = None
165
+ if isinstance(freqs_cis, tuple):
166
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
168
+ # real * cos - imag * sin
169
+ # imag * cos + real * sin
170
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
+ else:
173
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
+ xq_ = torch.view_as_complex(
175
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
176
+ ) # [B, S, H, D//2]
177
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
+ xq.device
179
+ ) # [S, D//2] --> [1, S, 1, D//2]
180
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
+ xk_ = torch.view_as_complex(
184
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
185
+ ) # [B, S, H, D//2]
186
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
+
188
+ return xq_out, xk_out
189
+
190
+
191
+ def get_nd_rotary_pos_embed(
192
+ rope_dim_list,
193
+ start,
194
+ *args,
195
+ theta=10000.0,
196
+ use_real=False,
197
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
198
+ interpolation_factor: Union[float, List[float]] = 1.0,
199
+ ):
200
+ """
201
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
+
203
+ Args:
204
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
205
+ sum(rope_dim_list) should equal to head_dim of attention layer.
206
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
+ *args: See above.
209
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
212
+ part and an imaginary part separately.
213
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
+
215
+ Returns:
216
+ pos_embed (torch.Tensor): [HW, D/2]
217
+ """
218
+
219
+ grid = get_meshgrid_nd(
220
+ start, *args, dim=len(rope_dim_list)
221
+ ) # [3, W, H, D] / [2, W, H]
222
+
223
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
+ assert len(theta_rescale_factor) == len(
228
+ rope_dim_list
229
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
+
231
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
+ assert len(interpolation_factor) == len(
236
+ rope_dim_list
237
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
+
239
+ # use 1/ndim of dimensions to encode grid_axis
240
+ embs = []
241
+ for i in range(len(rope_dim_list)):
242
+ emb = get_1d_rotary_pos_embed(
243
+ rope_dim_list[i],
244
+ grid[i].reshape(-1),
245
+ theta,
246
+ use_real=use_real,
247
+ theta_rescale_factor=theta_rescale_factor[i],
248
+ interpolation_factor=interpolation_factor[i],
249
+ ) # 2 x [WHD, rope_dim_list[i]]
250
+ embs.append(emb)
251
+
252
+ if use_real:
253
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
+ return cos, sin
256
+ else:
257
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
+ return emb
259
+
260
+
261
+ def get_1d_rotary_pos_embed(
262
+ dim: int,
263
+ pos: Union[torch.FloatTensor, int],
264
+ theta: float = 10000.0,
265
+ use_real: bool = False,
266
+ theta_rescale_factor: float = 1.0,
267
+ interpolation_factor: float = 1.0,
268
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
+ """
270
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
+
273
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
275
+ The returned tensor contains complex values in complex64 data type.
276
+
277
+ Args:
278
+ dim (int): Dimension of the frequency tensor.
279
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
+ use_real (bool, optional): If True, return real part and imaginary part separately.
282
+ Otherwise, return complex numbers.
283
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
+
285
+ Returns:
286
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
+ """
289
+ if isinstance(pos, int):
290
+ pos = torch.arange(pos).float()
291
+
292
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
+ # has some connection to NTK literature
294
+ if theta_rescale_factor != 1.0:
295
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
296
+
297
+ freqs = 1.0 / (
298
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
+ ) # [D/2]
300
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
+ if use_real:
303
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
+ return freqs_cos, freqs_sin
306
+ else:
307
+ freqs_cis = torch.polar(
308
+ torch.ones_like(freqs), freqs
309
+ ) # complex64 # [S, D/2]
310
+ return freqs_cis
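A minimal usage sketch for the RoPE helpers above, assuming a head dimension of 64 split across (T, H, W) as [16, 24, 24] and a toy latent grid of 4x8x8 tokens (these sizes are illustrative, not HunyuanVideo's actual configuration):

import torch

rope_dim_list = [16, 24, 24]  # must sum to the attention head_dim
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, (4, 8, 8), use_real=True)  # 2 x [256, 64]

B, S, H, D = 1, 4 * 8 * 8, 8, 64
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
q_rot, k_rot = apply_rotary_emb(q, k, (freqs_cos, freqs_sin), head_first=False)  # same shapes as q and k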
hunyuan_model/text_encoder.py ADDED
@@ -0,0 +1,438 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
8
+ from transformers.utils import ModelOutput
9
+ from transformers.models.llama import LlamaModel
10
+
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+
17
+ # When using decoder-only models, we must provide a prompt template to instruct the text encoder
18
+ # on how to generate the text.
19
+ # --------------------------------------------------------------------
20
+ PROMPT_TEMPLATE_ENCODE = (
21
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
22
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
23
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
24
+ )
25
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
26
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
27
+ "1. The main content and theme of the video."
28
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
29
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
30
+ "4. background environment, light, style and atmosphere."
31
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
32
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
33
+ )
34
+
35
+ NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
36
+
37
+ PROMPT_TEMPLATE = {
38
+ "dit-llm-encode": {
39
+ "template": PROMPT_TEMPLATE_ENCODE,
40
+ "crop_start": 36,
41
+ },
42
+ "dit-llm-encode-video": {
43
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
44
+ "crop_start": 95,
45
+ },
46
+ }
47
+
48
+
49
+ def use_default(value, default):
50
+ return value if value is not None else default
51
+
52
+
53
+ def load_text_encoder(
54
+ text_encoder_type: str,
55
+ text_encoder_path: str,
56
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
57
+ ):
58
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
59
+
60
+ # reduce peak memory usage by specifying the dtype of the model
61
+ dtype = text_encoder_dtype
62
+ if text_encoder_type == "clipL":
63
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
64
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
65
+ elif text_encoder_type == "llm":
66
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
67
+ text_encoder.final_layer_norm = text_encoder.norm
68
+ else:
69
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
70
+ # from_pretrained will ensure that the model is in eval mode.
71
+
72
+ if dtype is not None:
73
+ text_encoder = text_encoder.to(dtype=dtype)
74
+
75
+ text_encoder.requires_grad_(False)
76
+
77
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
78
+ return text_encoder, text_encoder_path
79
+
80
+
81
+ def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
82
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
83
+
84
+ if tokenizer_type == "clipL":
85
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
86
+ elif tokenizer_type == "llm":
87
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
88
+ else:
89
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
90
+
91
+ return tokenizer, tokenizer_path
92
+
93
+
94
+ @dataclass
95
+ class TextEncoderModelOutput(ModelOutput):
96
+ """
97
+ Base class for model's outputs that also contains a pooling of the last hidden states.
98
+
99
+ Args:
100
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
101
+ Sequence of hidden-states at the output of the last layer of the model.
102
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
103
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
104
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
105
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
106
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
107
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
108
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
109
+ List of decoded texts.
110
+ """
111
+
112
+ hidden_state: torch.FloatTensor = None
113
+ attention_mask: Optional[torch.LongTensor] = None
114
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
115
+ text_outputs: Optional[list] = None
116
+
117
+
118
+ class TextEncoder(nn.Module):
119
+ def __init__(
120
+ self,
121
+ text_encoder_type: str,
122
+ max_length: int,
123
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
124
+ text_encoder_path: Optional[str] = None,
125
+ tokenizer_type: Optional[str] = None,
126
+ tokenizer_path: Optional[str] = None,
127
+ output_key: Optional[str] = None,
128
+ use_attention_mask: bool = True,
129
+ input_max_length: Optional[int] = None,
130
+ prompt_template: Optional[dict] = None,
131
+ prompt_template_video: Optional[dict] = None,
132
+ hidden_state_skip_layer: Optional[int] = None,
133
+ apply_final_norm: bool = False,
134
+ reproduce: bool = False,
135
+ ):
136
+ super().__init__()
137
+ self.text_encoder_type = text_encoder_type
138
+ self.max_length = max_length
139
+ # self.precision = text_encoder_precision
140
+ self.model_path = text_encoder_path
141
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
142
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
143
+ self.use_attention_mask = use_attention_mask
144
+ if prompt_template_video is not None:
145
+ assert use_attention_mask is True, "use_attention_mask must be True when training on videos."
146
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
147
+ self.prompt_template = prompt_template
148
+ self.prompt_template_video = prompt_template_video
149
+ self.hidden_state_skip_layer = hidden_state_skip_layer
150
+ self.apply_final_norm = apply_final_norm
151
+ self.reproduce = reproduce
152
+
153
+ self.use_template = self.prompt_template is not None
154
+ if self.use_template:
155
+ assert (
156
+ isinstance(self.prompt_template, dict) and "template" in self.prompt_template
157
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
158
+ assert "{}" in str(self.prompt_template["template"]), (
159
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
160
+ f"got {self.prompt_template['template']}"
161
+ )
162
+
163
+ self.use_video_template = self.prompt_template_video is not None
164
+ if self.use_video_template:
165
+ if self.prompt_template_video is not None:
166
+ assert (
167
+ isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
168
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
169
+ assert "{}" in str(self.prompt_template_video["template"]), (
170
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
171
+ f"got {self.prompt_template_video['template']}"
172
+ )
173
+
174
+ if "t5" in text_encoder_type:
175
+ self.output_key = output_key or "last_hidden_state"
176
+ elif "clip" in text_encoder_type:
177
+ self.output_key = output_key or "pooler_output"
178
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
179
+ self.output_key = output_key or "last_hidden_state"
180
+ else:
181
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
182
+
183
+ self.model, self.model_path = load_text_encoder(
184
+ text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
185
+ )
186
+ self.dtype = self.model.dtype
187
+
188
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
189
+ tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
190
+ )
191
+
192
+ def __repr__(self):
193
+ return f"{self.text_encoder_type} ({self.dtype} - {self.model_path})"
194
+
195
+ @property
196
+ def device(self):
197
+ return self.model.device
198
+
199
+ @staticmethod
200
+ def apply_text_to_template(text, template, prevent_empty_text=True):
201
+ """
202
+ Apply text to template.
203
+
204
+ Args:
205
+ text (str): Input text.
206
+ template (str or list): Template string or list of chat conversation.
207
+ prevent_empty_text (bool): If True, we will prevent the user text from being empty
208
+ by adding a space. Defaults to True.
209
+ """
210
+ if isinstance(template, str):
211
+ # Will send string to tokenizer. Used for llm
212
+ return template.format(text)
213
+ else:
214
+ raise TypeError(f"Unsupported template type: {type(template)}")
215
+
216
+ def text2tokens(self, text, data_type="image"):
217
+ """
218
+ Tokenize the input text.
219
+
220
+ Args:
221
+ text (str or list): Input text.
222
+ """
223
+ tokenize_input_type = "str"
224
+ if self.use_template:
225
+ if data_type == "image":
226
+ prompt_template = self.prompt_template["template"]
227
+ elif data_type == "video":
228
+ prompt_template = self.prompt_template_video["template"]
229
+ else:
230
+ raise ValueError(f"Unsupported data type: {data_type}")
231
+ if isinstance(text, (list, tuple)):
232
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
233
+ if isinstance(text[0], list):
234
+ tokenize_input_type = "list"
235
+ elif isinstance(text, str):
236
+ text = self.apply_text_to_template(text, prompt_template)
237
+ if isinstance(text, list):
238
+ tokenize_input_type = "list"
239
+ else:
240
+ raise TypeError(f"Unsupported text type: {type(text)}")
241
+
242
+ kwargs = dict(
243
+ truncation=True,
244
+ max_length=self.max_length,
245
+ padding="max_length",
246
+ return_tensors="pt",
247
+ )
248
+ if tokenize_input_type == "str":
249
+ return self.tokenizer(
250
+ text,
251
+ return_length=False,
252
+ return_overflowing_tokens=False,
253
+ return_attention_mask=True,
254
+ **kwargs,
255
+ )
256
+ elif tokenize_input_type == "list":
257
+ return self.tokenizer.apply_chat_template(
258
+ text,
259
+ add_generation_prompt=True,
260
+ tokenize=True,
261
+ return_dict=True,
262
+ **kwargs,
263
+ )
264
+ else:
265
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
266
+
267
+ def encode(
268
+ self,
269
+ batch_encoding,
270
+ use_attention_mask=None,
271
+ output_hidden_states=False,
272
+ do_sample=None,
273
+ hidden_state_skip_layer=None,
274
+ return_texts=False,
275
+ data_type="image",
276
+ device=None,
277
+ ):
278
+ """
279
+ Args:
280
+ batch_encoding (dict): Batch encoding from tokenizer.
281
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
282
+ Defaults to None.
283
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
284
+ self.output_key. If True, return the entire output. If self.hidden_state_skip_layer is set,
285
+ output_hidden_states will be set True. Defaults to False.
286
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
287
+ When self.reproduce is False, do_sample is set to True by default.
288
+ hidden_state_skip_layer (int): Number of layers to skip from the end when selecting the hidden state. 0 means the last layer.
289
+ If None, self.output_key will be used. Defaults to None.
290
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
291
+ """
292
+ device = self.model.device if device is None else device
293
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
294
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
295
+ do_sample = use_default(do_sample, not self.reproduce)
296
+ attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
297
+ outputs = self.model(
298
+ input_ids=batch_encoding["input_ids"].to(device),
299
+ attention_mask=attention_mask,
300
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
301
+ )
302
+ if hidden_state_skip_layer is not None:
303
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
304
+ # Real last hidden state already has layer norm applied. So here we only apply it
305
+ # for intermediate layers.
306
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
307
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
308
+ else:
309
+ last_hidden_state = outputs[self.output_key]
310
+
311
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
312
+ if self.use_template:
313
+ if data_type == "image":
314
+ crop_start = self.prompt_template.get("crop_start", -1)
315
+ elif data_type == "video":
316
+ crop_start = self.prompt_template_video.get("crop_start", -1)
317
+ else:
318
+ raise ValueError(f"Unsupported data type: {data_type}")
319
+ if crop_start > 0:
320
+ last_hidden_state = last_hidden_state[:, crop_start:]
321
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
322
+
323
+ if output_hidden_states:
324
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
325
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
326
+
327
+ def forward(
328
+ self,
329
+ text,
330
+ use_attention_mask=None,
331
+ output_hidden_states=False,
332
+ do_sample=False,
333
+ hidden_state_skip_layer=None,
334
+ return_texts=False,
335
+ ):
336
+ batch_encoding = self.text2tokens(text)
337
+ return self.encode(
338
+ batch_encoding,
339
+ use_attention_mask=use_attention_mask,
340
+ output_hidden_states=output_hidden_states,
341
+ do_sample=do_sample,
342
+ hidden_state_skip_layer=hidden_state_skip_layer,
343
+ return_texts=return_texts,
344
+ )
345
+
346
+
347
+ # region HunyuanVideo architecture
348
+
349
+
350
+ def load_text_encoder_1(
351
+ text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
352
+ ) -> TextEncoder:
353
+ text_encoder_dtype = dtype or torch.float16
354
+ text_encoder_type = "llm"
355
+ text_len = 256
356
+ hidden_state_skip_layer = 2
357
+ apply_final_norm = False
358
+ reproduce = False
359
+
360
+ prompt_template = "dit-llm-encode"
361
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
362
+ prompt_template_video = "dit-llm-encode-video"
363
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
364
+
365
+ crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
366
+ max_length = text_len + crop_start
367
+
368
+ text_encoder_1 = TextEncoder(
369
+ text_encoder_type=text_encoder_type,
370
+ max_length=max_length,
371
+ text_encoder_dtype=text_encoder_dtype,
372
+ text_encoder_path=text_encoder_dir,
373
+ tokenizer_type=text_encoder_type,
374
+ prompt_template=prompt_template,
375
+ prompt_template_video=prompt_template_video,
376
+ hidden_state_skip_layer=hidden_state_skip_layer,
377
+ apply_final_norm=apply_final_norm,
378
+ reproduce=reproduce,
379
+ )
380
+ text_encoder_1.eval()
381
+
382
+ if fp8_llm:
383
+ org_dtype = text_encoder_1.dtype
384
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
385
+ text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
386
+
387
+ # prepare LLM for fp8
388
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
389
+ def forward_hook(module):
390
+ def forward(hidden_states):
391
+ input_dtype = hidden_states.dtype
392
+ hidden_states = hidden_states.to(torch.float32)
393
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
394
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
395
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
396
+
397
+ return forward
398
+
399
+ for module in llama_model.modules():
400
+ if module.__class__.__name__ in ["Embedding"]:
401
+ # print("set", module.__class__.__name__, "to", target_dtype)
402
+ module.to(target_dtype)
403
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
404
+ # print("set", module.__class__.__name__, "hooks")
405
+ module.forward = forward_hook(module)
406
+
407
+ prepare_fp8(text_encoder_1.model, org_dtype)
408
+ else:
409
+ text_encoder_1.to(device=device)
410
+
411
+ return text_encoder_1
412
+
413
+
414
+ def load_text_encoder_2(
415
+ text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
416
+ ) -> TextEncoder:
417
+ text_encoder_dtype = dtype or torch.float16
418
+ reproduce = False
419
+
420
+ text_encoder_2_type = "clipL"
421
+ text_len_2 = 77
422
+
423
+ text_encoder_2 = TextEncoder(
424
+ text_encoder_type=text_encoder_2_type,
425
+ max_length=text_len_2,
426
+ text_encoder_dtype=text_encoder_dtype,
427
+ text_encoder_path=text_encoder_dir,
428
+ tokenizer_type=text_encoder_2_type,
429
+ reproduce=reproduce,
430
+ )
431
+ text_encoder_2.eval()
432
+
433
+ text_encoder_2.to(device=device)
434
+
435
+ return text_encoder_2
436
+
437
+
438
+ # endregion
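A minimal sketch of how the two loaders above might be wired together; the checkpoint directories are placeholders and the shape comments assume the Hunyuan LLaMA/CLIP-L encoders:

import torch

text_encoder_1 = load_text_encoder_1("ckpts/text_encoder", torch.device("cuda"), fp8_llm=False, dtype=torch.float16)
text_encoder_2 = load_text_encoder_2("ckpts/text_encoder_2", torch.device("cuda"), dtype=torch.float16)

prompt = "A cat walks on the grass, realistic style."
with torch.no_grad():
    tokens = text_encoder_1.text2tokens(prompt, data_type="video")
    llm_out = text_encoder_1.encode(tokens, data_type="video")  # hidden_state ~ [1, 256, 4096]
    tokens_2 = text_encoder_2.text2tokens(prompt)
    clip_out = text_encoder_2.encode(tokens_2)                  # pooler_output ~ [1, 768]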
hunyuan_model/token_refiner.py ADDED
@@ -0,0 +1,236 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+ from .activation_layers import get_activation_layer
9
+ from .attention import attention
10
+ from .norm_layers import get_norm_layer
11
+ from .embed_layers import TimestepEmbedder, TextProjection
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ hidden_size,
20
+ heads_num,
21
+ mlp_width_ratio: float = 4.0,
22
+ mlp_drop_rate: float = 0.0,
23
+ act_type: str = "silu",
24
+ qk_norm: bool = False,
25
+ qk_norm_type: str = "layer",
26
+ qkv_bias: bool = True,
27
+ dtype: Optional[torch.dtype] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super().__init__()
32
+ self.heads_num = heads_num
33
+ head_dim = hidden_size // heads_num
34
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
+
36
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
37
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
38
+ qk_norm_layer = get_norm_layer(qk_norm_type)
39
+ self.self_attn_q_norm = (
40
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
41
+ )
42
+ self.self_attn_k_norm = (
43
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
44
+ )
45
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
46
+
47
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
48
+ act_layer = get_activation_layer(act_type)
49
+ self.mlp = MLP(
50
+ in_channels=hidden_size,
51
+ hidden_channels=mlp_hidden_dim,
52
+ act_layer=act_layer,
53
+ drop=mlp_drop_rate,
54
+ **factory_kwargs,
55
+ )
56
+
57
+ self.adaLN_modulation = nn.Sequential(
58
+ act_layer(),
59
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
60
+ )
61
+ # Zero-initialize the modulation
62
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
63
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
64
+
65
+ self.gradient_checkpointing = False
66
+
67
+ def enable_gradient_checkpointing(self):
68
+ self.gradient_checkpointing = True
69
+
70
+ def _forward(
71
+ self,
72
+ x: torch.Tensor,
73
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
74
+ attn_mask: torch.Tensor = None,
75
+ ):
76
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
77
+
78
+ norm_x = self.norm1(x)
79
+ qkv = self.self_attn_qkv(norm_x)
80
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
81
+ # Apply QK-Norm if needed
82
+ q = self.self_attn_q_norm(q).to(v)
83
+ k = self.self_attn_k_norm(k).to(v)
84
+
85
+ # Self-Attention
86
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
87
+
88
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
89
+
90
+ # FFN Layer
91
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
92
+
93
+ return x
94
+
95
+ def forward(self, *args, **kwargs):
96
+ if self.training and self.gradient_checkpointing:
97
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
98
+ else:
99
+ return self._forward(*args, **kwargs)
100
+
101
+
102
+
103
+ class IndividualTokenRefiner(nn.Module):
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ heads_num,
108
+ depth,
109
+ mlp_width_ratio: float = 4.0,
110
+ mlp_drop_rate: float = 0.0,
111
+ act_type: str = "silu",
112
+ qk_norm: bool = False,
113
+ qk_norm_type: str = "layer",
114
+ qkv_bias: bool = True,
115
+ dtype: Optional[torch.dtype] = None,
116
+ device: Optional[torch.device] = None,
117
+ ):
118
+ factory_kwargs = {"device": device, "dtype": dtype}
119
+ super().__init__()
120
+ self.blocks = nn.ModuleList(
121
+ [
122
+ IndividualTokenRefinerBlock(
123
+ hidden_size=hidden_size,
124
+ heads_num=heads_num,
125
+ mlp_width_ratio=mlp_width_ratio,
126
+ mlp_drop_rate=mlp_drop_rate,
127
+ act_type=act_type,
128
+ qk_norm=qk_norm,
129
+ qk_norm_type=qk_norm_type,
130
+ qkv_bias=qkv_bias,
131
+ **factory_kwargs,
132
+ )
133
+ for _ in range(depth)
134
+ ]
135
+ )
136
+
137
+ def enable_gradient_checkpointing(self):
138
+ for block in self.blocks:
139
+ block.enable_gradient_checkpointing()
140
+
141
+ def forward(
142
+ self,
143
+ x: torch.Tensor,
144
+ c: torch.LongTensor,
145
+ mask: Optional[torch.Tensor] = None,
146
+ ):
147
+ self_attn_mask = None
148
+ if mask is not None:
149
+ batch_size = mask.shape[0]
150
+ seq_len = mask.shape[1]
151
+ mask = mask.to(x.device)
152
+ # batch_size x 1 x seq_len x seq_len
153
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
154
+ # batch_size x 1 x seq_len x seq_len
155
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
156
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
157
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
158
+ # avoids self-attention weight being NaN for padding tokens
159
+ self_attn_mask[:, :, :, 0] = True
160
+
161
+ for block in self.blocks:
162
+ x = block(x, c, self_attn_mask)
163
+ return x
164
+
165
+
166
+ class SingleTokenRefiner(nn.Module):
167
+ """
168
+ A single token refiner block for llm text embedding refine.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ in_channels,
174
+ hidden_size,
175
+ heads_num,
176
+ depth,
177
+ mlp_width_ratio: float = 4.0,
178
+ mlp_drop_rate: float = 0.0,
179
+ act_type: str = "silu",
180
+ qk_norm: bool = False,
181
+ qk_norm_type: str = "layer",
182
+ qkv_bias: bool = True,
183
+ attn_mode: str = "torch",
184
+ dtype: Optional[torch.dtype] = None,
185
+ device: Optional[torch.device] = None,
186
+ ):
187
+ factory_kwargs = {"device": device, "dtype": dtype}
188
+ super().__init__()
189
+ self.attn_mode = attn_mode
190
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
191
+
192
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
193
+
194
+ act_layer = get_activation_layer(act_type)
195
+ # Build timestep embedding layer
196
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
197
+ # Build context embedding layer
198
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
199
+
200
+ self.individual_token_refiner = IndividualTokenRefiner(
201
+ hidden_size=hidden_size,
202
+ heads_num=heads_num,
203
+ depth=depth,
204
+ mlp_width_ratio=mlp_width_ratio,
205
+ mlp_drop_rate=mlp_drop_rate,
206
+ act_type=act_type,
207
+ qk_norm=qk_norm,
208
+ qk_norm_type=qk_norm_type,
209
+ qkv_bias=qkv_bias,
210
+ **factory_kwargs,
211
+ )
212
+
213
+ def enable_gradient_checkpointing(self):
214
+ self.individual_token_refiner.enable_gradient_checkpointing()
215
+
216
+ def forward(
217
+ self,
218
+ x: torch.Tensor,
219
+ t: torch.LongTensor,
220
+ mask: Optional[torch.LongTensor] = None,
221
+ ):
222
+ timestep_aware_representations = self.t_embedder(t)
223
+
224
+ if mask is None:
225
+ context_aware_representations = x.mean(dim=1)
226
+ else:
227
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
228
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
229
+ context_aware_representations = self.c_embedder(context_aware_representations)
230
+ c = timestep_aware_representations + context_aware_representations
231
+
232
+ x = self.input_embedder(x)
233
+
234
+ x = self.individual_token_refiner(x, c, mask)
235
+
236
+ return x
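The padding-mask construction in IndividualTokenRefiner.forward above can be checked in isolation. A minimal sketch with a toy batch of one 4-token prompt whose last token is padding (the sizes are illustrative only):

import torch

mask = torch.tensor([[1, 1, 1, 0]])            # [batch, seq_len]
m1 = mask.view(1, 1, 1, 4).repeat(1, 1, 4, 1)  # each query row sees which keys are valid
m2 = m1.transpose(2, 3)                        # zero out rows belonging to padded queries
self_attn_mask = (m1 & m2).bool()              # [batch, 1, seq_len, seq_len]
self_attn_mask[:, :, :, 0] = True              # keep one valid key per row so softmax never sees all -inf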
hunyuan_model/vae.py ADDED
@@ -0,0 +1,442 @@
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import Optional, Tuple, Union
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from diffusers.utils import BaseOutput, is_torch_version
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.models.attention_processor import SpatialNorm
13
+ from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
14
+
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+
21
+ SCALING_FACTOR = 0.476986
22
+ VAE_VER = "884-16c-hy"
23
+
24
+
25
+ def load_vae(
26
+ vae_type: str = "884-16c-hy",
27
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
28
+ sample_size: tuple = None,
29
+ vae_path: str = None,
30
+ device=None,
31
+ ):
32
+ """the function to load the 3D VAE model
33
+
34
+ Args:
35
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
36
+ vae_dtype (str or torch.dtype, optional): the dtype to load the VAE with. Defaults to None.
37
+ sample_size (tuple, optional): the tiling size. Defaults to None.
38
+ vae_path (str, optional): the path to vae. Defaults to None.
39
+ logger (_type_, optional): logger. Defaults to None.
40
+ device (_type_, optional): device to load vae. Defaults to None.
41
+ """
42
+ if vae_path is None:
43
+ vae_path = VAE_PATH[vae_type]
44
+
45
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
46
+
47
+ # use fixed config for Hunyuan's VAE
48
+ CONFIG_JSON = """{
49
+ "_class_name": "AutoencoderKLCausal3D",
50
+ "_diffusers_version": "0.4.2",
51
+ "act_fn": "silu",
52
+ "block_out_channels": [
53
+ 128,
54
+ 256,
55
+ 512,
56
+ 512
57
+ ],
58
+ "down_block_types": [
59
+ "DownEncoderBlockCausal3D",
60
+ "DownEncoderBlockCausal3D",
61
+ "DownEncoderBlockCausal3D",
62
+ "DownEncoderBlockCausal3D"
63
+ ],
64
+ "in_channels": 3,
65
+ "latent_channels": 16,
66
+ "layers_per_block": 2,
67
+ "norm_num_groups": 32,
68
+ "out_channels": 3,
69
+ "sample_size": 256,
70
+ "sample_tsize": 64,
71
+ "up_block_types": [
72
+ "UpDecoderBlockCausal3D",
73
+ "UpDecoderBlockCausal3D",
74
+ "UpDecoderBlockCausal3D",
75
+ "UpDecoderBlockCausal3D"
76
+ ],
77
+ "scaling_factor": 0.476986,
78
+ "time_compression_ratio": 4,
79
+ "mid_block_add_attention": true
80
+ }"""
81
+
82
+ # config = AutoencoderKLCausal3D.load_config(vae_path)
83
+ config = json.loads(CONFIG_JSON)
84
+
85
+ # import here to avoid circular import
86
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
87
+
88
+ if sample_size:
89
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
90
+ else:
91
+ vae = AutoencoderKLCausal3D.from_config(config)
92
+
93
+ # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
94
+ # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
95
+
96
+ ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
97
+ if "state_dict" in ckpt:
98
+ ckpt = ckpt["state_dict"]
99
+ if any(k.startswith("vae.") for k in ckpt.keys()):
100
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
101
+ vae.load_state_dict(ckpt)
102
+
103
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
104
+ time_compression_ratio = vae.config.time_compression_ratio
105
+
106
+ if vae_dtype is not None:
107
+ vae = vae.to(vae_dtype)
108
+
109
+ vae.requires_grad_(False)
110
+
111
+ logger.info(f"VAE to dtype: {vae.dtype}")
112
+
113
+ if device is not None:
114
+ vae = vae.to(device)
115
+
116
+ vae.eval()
117
+
118
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
119
+
120
+
121
+ @dataclass
122
+ class DecoderOutput(BaseOutput):
123
+ r"""
124
+ Output of decoding method.
125
+
126
+ Args:
127
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
128
+ The decoded output sample from the last layer of the model.
129
+ """
130
+
131
+ sample: torch.FloatTensor
132
+
133
+
134
+ class EncoderCausal3D(nn.Module):
135
+ r"""
136
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ in_channels: int = 3,
142
+ out_channels: int = 3,
143
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
144
+ block_out_channels: Tuple[int, ...] = (64,),
145
+ layers_per_block: int = 2,
146
+ norm_num_groups: int = 32,
147
+ act_fn: str = "silu",
148
+ double_z: bool = True,
149
+ mid_block_add_attention=True,
150
+ time_compression_ratio: int = 4,
151
+ spatial_compression_ratio: int = 8,
152
+ ):
153
+ super().__init__()
154
+ self.layers_per_block = layers_per_block
155
+
156
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
157
+ self.mid_block = None
158
+ self.down_blocks = nn.ModuleList([])
159
+
160
+ # down
161
+ output_channel = block_out_channels[0]
162
+ for i, down_block_type in enumerate(down_block_types):
163
+ input_channel = output_channel
164
+ output_channel = block_out_channels[i]
165
+ is_final_block = i == len(block_out_channels) - 1
166
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
167
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
168
+
169
+ if time_compression_ratio == 4:
170
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
171
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
172
+ else:
173
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
174
+
175
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
176
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
177
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
178
+ down_block = get_down_block3d(
179
+ down_block_type,
180
+ num_layers=self.layers_per_block,
181
+ in_channels=input_channel,
182
+ out_channels=output_channel,
183
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
184
+ downsample_stride=downsample_stride,
185
+ resnet_eps=1e-6,
186
+ downsample_padding=0,
187
+ resnet_act_fn=act_fn,
188
+ resnet_groups=norm_num_groups,
189
+ attention_head_dim=output_channel,
190
+ temb_channels=None,
191
+ )
192
+ self.down_blocks.append(down_block)
193
+
194
+ # mid
195
+ self.mid_block = UNetMidBlockCausal3D(
196
+ in_channels=block_out_channels[-1],
197
+ resnet_eps=1e-6,
198
+ resnet_act_fn=act_fn,
199
+ output_scale_factor=1,
200
+ resnet_time_scale_shift="default",
201
+ attention_head_dim=block_out_channels[-1],
202
+ resnet_groups=norm_num_groups,
203
+ temb_channels=None,
204
+ add_attention=mid_block_add_attention,
205
+ )
206
+
207
+ # out
208
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
209
+ self.conv_act = nn.SiLU()
210
+
211
+ conv_out_channels = 2 * out_channels if double_z else out_channels
212
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
213
+
214
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
215
+ r"""The forward method of the `EncoderCausal3D` class."""
216
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
217
+
218
+ sample = self.conv_in(sample)
219
+
220
+ # down
221
+ for down_block in self.down_blocks:
222
+ sample = down_block(sample)
223
+
224
+ # middle
225
+ sample = self.mid_block(sample)
226
+
227
+ # post-process
228
+ sample = self.conv_norm_out(sample)
229
+ sample = self.conv_act(sample)
230
+ sample = self.conv_out(sample)
231
+
232
+ return sample
233
+
234
+
235
+ class DecoderCausal3D(nn.Module):
236
+ r"""
237
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
238
+ """
239
+
240
+ def __init__(
241
+ self,
242
+ in_channels: int = 3,
243
+ out_channels: int = 3,
244
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
245
+ block_out_channels: Tuple[int, ...] = (64,),
246
+ layers_per_block: int = 2,
247
+ norm_num_groups: int = 32,
248
+ act_fn: str = "silu",
249
+ norm_type: str = "group", # group, spatial
250
+ mid_block_add_attention=True,
251
+ time_compression_ratio: int = 4,
252
+ spatial_compression_ratio: int = 8,
253
+ ):
254
+ super().__init__()
255
+ self.layers_per_block = layers_per_block
256
+
257
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
258
+ self.mid_block = None
259
+ self.up_blocks = nn.ModuleList([])
260
+
261
+ temb_channels = in_channels if norm_type == "spatial" else None
262
+
263
+ # mid
264
+ self.mid_block = UNetMidBlockCausal3D(
265
+ in_channels=block_out_channels[-1],
266
+ resnet_eps=1e-6,
267
+ resnet_act_fn=act_fn,
268
+ output_scale_factor=1,
269
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
270
+ attention_head_dim=block_out_channels[-1],
271
+ resnet_groups=norm_num_groups,
272
+ temb_channels=temb_channels,
273
+ add_attention=mid_block_add_attention,
274
+ )
275
+
276
+ # up
277
+ reversed_block_out_channels = list(reversed(block_out_channels))
278
+ output_channel = reversed_block_out_channels[0]
279
+ for i, up_block_type in enumerate(up_block_types):
280
+ prev_output_channel = output_channel
281
+ output_channel = reversed_block_out_channels[i]
282
+ is_final_block = i == len(block_out_channels) - 1
283
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
284
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
285
+
286
+ if time_compression_ratio == 4:
287
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
288
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
289
+ else:
290
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
291
+
292
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
293
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
294
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
295
+ up_block = get_up_block3d(
296
+ up_block_type,
297
+ num_layers=self.layers_per_block + 1,
298
+ in_channels=prev_output_channel,
299
+ out_channels=output_channel,
300
+ prev_output_channel=None,
301
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
302
+ upsample_scale_factor=upsample_scale_factor,
303
+ resnet_eps=1e-6,
304
+ resnet_act_fn=act_fn,
305
+ resnet_groups=norm_num_groups,
306
+ attention_head_dim=output_channel,
307
+ temb_channels=temb_channels,
308
+ resnet_time_scale_shift=norm_type,
309
+ )
310
+ self.up_blocks.append(up_block)
311
+ prev_output_channel = output_channel
312
+
313
+ # out
314
+ if norm_type == "spatial":
315
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
316
+ else:
317
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
318
+ self.conv_act = nn.SiLU()
319
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
320
+
321
+ self.gradient_checkpointing = False
322
+
323
+ def forward(
324
+ self,
325
+ sample: torch.FloatTensor,
326
+ latent_embeds: Optional[torch.FloatTensor] = None,
327
+ ) -> torch.FloatTensor:
328
+ r"""The forward method of the `DecoderCausal3D` class."""
329
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
330
+
331
+ sample = self.conv_in(sample)
332
+
333
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
334
+ if self.training and self.gradient_checkpointing:
335
+
336
+ def create_custom_forward(module):
337
+ def custom_forward(*inputs):
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ if is_torch_version(">=", "1.11.0"):
343
+ # middle
344
+ sample = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(self.mid_block),
346
+ sample,
347
+ latent_embeds,
348
+ use_reentrant=False,
349
+ )
350
+ sample = sample.to(upscale_dtype)
351
+
352
+ # up
353
+ for up_block in self.up_blocks:
354
+ sample = torch.utils.checkpoint.checkpoint(
355
+ create_custom_forward(up_block),
356
+ sample,
357
+ latent_embeds,
358
+ use_reentrant=False,
359
+ )
360
+ else:
361
+ # middle
362
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
363
+ sample = sample.to(upscale_dtype)
364
+
365
+ # up
366
+ for up_block in self.up_blocks:
367
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
368
+ else:
369
+ # middle
370
+ sample = self.mid_block(sample, latent_embeds)
371
+ sample = sample.to(upscale_dtype)
372
+
373
+ # up
374
+ for up_block in self.up_blocks:
375
+ sample = up_block(sample, latent_embeds)
376
+
377
+ # post-process
378
+ if latent_embeds is None:
379
+ sample = self.conv_norm_out(sample)
380
+ else:
381
+ sample = self.conv_norm_out(sample, latent_embeds)
382
+ sample = self.conv_act(sample)
383
+ sample = self.conv_out(sample)
384
+
385
+ return sample
386
+
387
+
388
+ class DiagonalGaussianDistribution(object):
389
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
390
+ if parameters.ndim == 3:
391
+ dim = 2 # (B, L, C)
392
+ elif parameters.ndim == 5 or parameters.ndim == 4:
393
+ dim = 1 # (B, C, T, H, W) / (B, C, H, W)
394
+ else:
395
+ raise NotImplementedError
396
+ self.parameters = parameters
397
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
398
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
399
+ self.deterministic = deterministic
400
+ self.std = torch.exp(0.5 * self.logvar)
401
+ self.var = torch.exp(self.logvar)
402
+ if self.deterministic:
403
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
404
+
405
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
406
+ # make sure sample is on the same device as the parameters and has same dtype
407
+ sample = randn_tensor(
408
+ self.mean.shape,
409
+ generator=generator,
410
+ device=self.parameters.device,
411
+ dtype=self.parameters.dtype,
412
+ )
413
+ x = self.mean + self.std * sample
414
+ return x
415
+
416
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
417
+ if self.deterministic:
418
+ return torch.Tensor([0.0])
419
+ else:
420
+ reduce_dim = list(range(1, self.mean.ndim))
421
+ if other is None:
422
+ return 0.5 * torch.sum(
423
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
424
+ dim=reduce_dim,
425
+ )
426
+ else:
427
+ return 0.5 * torch.sum(
428
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
429
+ dim=reduce_dim,
430
+ )
431
+
432
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
433
+ if self.deterministic:
434
+ return torch.Tensor([0.0])
435
+ logtwopi = np.log(2.0 * np.pi)
436
+ return 0.5 * torch.sum(
437
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
438
+ dim=dims,
439
+ )
440
+
441
+ def mode(self) -> torch.Tensor:
442
+ return self.mean
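A minimal usage sketch for the `DiagonalGaussianDistribution` above, assuming a hypothetical encoder output with 2*C channels (mean and log-variance stacked along dim 1 of a 5-D video latent); the shapes here are illustrative only:

import torch

moments = torch.randn(1, 2 * 16, 33, 32, 32)       # hypothetical (B, 2C, T, H, W) encoder output
posterior = DiagonalGaussianDistribution(moments)   # chunked into mean / logvar along dim 1
z_random = posterior.sample()                       # mean + std * noise, shape (1, 16, 33, 32, 32)
z_det = posterior.mode()                            # deterministic latent: just the mean
kl = posterior.kl()                                 # KL vs. a standard normal, summed over non-batch dims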
hv_generate_video.py ADDED
@@ -0,0 +1,563 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import random
5
+ import sys
6
+ import os
7
+ import time
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ import accelerate
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers.models.llama import LlamaModel
16
+ from tqdm import tqdm
17
+ import av
18
+ from einops import rearrange
19
+ from safetensors.torch import load_file
20
+
21
+ from hunyuan_model import vae
22
+ from hunyuan_model.text_encoder import TextEncoder
23
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
24
+ from hunyuan_model.vae import load_vae
25
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed
26
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
27
+ from networks import lora
28
+
29
+ import logging
30
+
31
+ logger = logging.getLogger(__name__)
32
+ logging.basicConfig(level=logging.INFO)
33
+
34
+
35
+ def clean_memory_on_device(device):
36
+ if device.type == "cuda":
37
+ torch.cuda.empty_cache()
38
+ elif device.type == "cpu":
39
+ pass
40
+ elif device.type == "mps": # not tested
41
+ torch.mps.empty_cache()
42
+
43
+
44
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
45
+ """save videos by video tensor
46
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
47
+
48
+ Args:
49
+ videos (torch.Tensor): video tensor predicted by the model
50
+ path (str): path to save video
51
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
52
+ n_rows (int, optional): Defaults to 1.
53
+ fps (int, optional): video save fps. Defaults to 24.
54
+ """
55
+ videos = rearrange(videos, "b c t h w -> t b c h w")
56
+ outputs = []
57
+ for x in videos:
58
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
59
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
60
+ if rescale:
61
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
62
+ x = torch.clamp(x, 0, 1)
63
+ x = (x * 255).numpy().astype(np.uint8)
64
+ outputs.append(x)
65
+
66
+ os.makedirs(os.path.dirname(path), exist_ok=True)
67
+
68
+ # # save video with av
69
+ # container = av.open(path, "w")
70
+ # stream = container.add_stream("libx264", rate=fps)
71
+ # for x in outputs:
72
+ # frame = av.VideoFrame.from_ndarray(x, format="rgb24")
73
+ # packet = stream.encode(frame)
74
+ # container.mux(packet)
75
+ # packet = stream.encode(None)
76
+ # container.mux(packet)
77
+ # container.close()
78
+
79
+ height, width, _ = outputs[0].shape
80
+
81
+ # create output container
82
+ container = av.open(path, mode="w")
83
+
84
+ # create video stream
85
+ codec = "libx264"
86
+ pixel_format = "yuv420p"
87
+ stream = container.add_stream(codec, rate=fps)
88
+ stream.width = width
89
+ stream.height = height
90
+ stream.pix_fmt = pixel_format
91
+ stream.bit_rate = 4000000 # 4Mbit/s
92
+
93
+ for frame_array in outputs:
94
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
95
+ packets = stream.encode(frame)
96
+ for packet in packets:
97
+ container.mux(packet)
98
+
99
+ for packet in stream.encode():
100
+ container.mux(packet)
101
+
102
+ container.close()
103
+
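# --- illustrative sketch: minimal use of save_videos_grid above. Assumptions: "out/dummy.mp4" is a
# hypothetical path, and frame sizes should be even so the yuv420p/libx264 encode does not fail.
# The function expects a (B, C, T, H, W) tensor in [0, 1], or in [-1, 1] together with rescale=True.
dummy_video = torch.rand(1, 3, 8, 256, 256)             # 8 RGB frames of 256x256, already in [0, 1]
save_videos_grid(dummy_video, "out/dummy.mp4", fps=24)  # writes an H.264 mp4 via PyAV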
104
+
105
+ # region Encoding prompt
106
+
107
+
108
+ def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
109
+ r"""
110
+ Encodes the prompt into text encoder hidden states.
111
+
112
+ Args:
113
+ prompt (`str` or `List[str]`):
114
+ prompt to be encoded
115
+ device: (`torch.device`):
116
+ torch device
117
+ num_videos_per_prompt (`int`):
118
+ number of videos that should be generated per prompt
119
+ text_encoder (TextEncoder):
120
+ text encoder to be used for encoding the prompt
121
+ """
122
+ # LoRA and Textual Inversion are not supported in this script
123
+ # negative prompt and prompt embedding are not supported in this script
124
+ # clip_skip is not supported in this script because it is not used in the original script
125
+ data_type = "video" # video only, image is not supported
126
+
127
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
128
+
129
+ with torch.no_grad():
130
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
131
+ prompt_embeds = prompt_outputs.hidden_state
132
+
133
+ attention_mask = prompt_outputs.attention_mask
134
+ if attention_mask is not None:
135
+ attention_mask = attention_mask.to(device)
136
+ bs_embed, seq_len = attention_mask.shape
137
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
138
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
139
+
140
+ prompt_embeds_dtype = text_encoder.dtype
141
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
142
+
143
+ if prompt_embeds.ndim == 2:
144
+ bs_embed, _ = prompt_embeds.shape
145
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
146
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
147
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
148
+ else:
149
+ bs_embed, seq_len, _ = prompt_embeds.shape
150
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
151
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
152
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
153
+
154
+ return prompt_embeds, attention_mask
155
+
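# --- illustrative sketch of the duplication above: with num_videos_per_prompt = 2, a (1, 256, 4096)
# hidden state becomes (2, 256, 4096), i.e. one identical copy per generated video ---
emb = torch.randn(1, 256, 4096)
emb = emb.repeat(1, 2, 1).view(1 * 2, 256, 4096)  # same repeat-then-view pattern as in encode_prompt
assert emb.shape == (2, 256, 4096)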
156
+
157
+ def encode_input_prompt(prompt, args, device, fp8_llm=False, accelerator=None):
158
+ # constants
159
+ prompt_template_video = "dit-llm-encode-video"
160
+ prompt_template = "dit-llm-encode"
161
+ text_encoder_dtype = torch.float16
162
+ text_encoder_type = "llm"
163
+ text_len = 256
164
+ hidden_state_skip_layer = 2
165
+ apply_final_norm = False
166
+ reproduce = False
167
+
168
+ text_encoder_2_type = "clipL"
169
+ text_len_2 = 77
170
+
171
+ num_videos = 1
172
+
173
+ # if args.prompt_template_video is not None:
174
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
175
+ # elif args.prompt_template is not None:
176
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
177
+ # else:
178
+ # crop_start = 0
179
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
180
+ max_length = text_len + crop_start
181
+
182
+ # prompt_template
183
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
184
+
185
+ # prompt_template_video
186
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
187
+
188
+ # load text encoders
189
+ logger.info(f"loading text encoder: {args.text_encoder1}")
190
+ text_encoder = TextEncoder(
191
+ text_encoder_type=text_encoder_type,
192
+ max_length=max_length,
193
+ text_encoder_dtype=text_encoder_dtype,
194
+ text_encoder_path=args.text_encoder1,
195
+ tokenizer_type=text_encoder_type,
196
+ prompt_template=prompt_template,
197
+ prompt_template_video=prompt_template_video,
198
+ hidden_state_skip_layer=hidden_state_skip_layer,
199
+ apply_final_norm=apply_final_norm,
200
+ reproduce=reproduce,
201
+ )
202
+ text_encoder.eval()
203
+ if fp8_llm:
204
+ org_dtype = text_encoder.dtype
205
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
206
+ text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
207
+
208
+ # prepare LLM for fp8
209
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
210
+ def forward_hook(module):
211
+ def forward(hidden_states):
212
+ input_dtype = hidden_states.dtype
213
+ hidden_states = hidden_states.to(torch.float32)
214
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
215
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
216
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
217
+
218
+ return forward
219
+
220
+ for module in llama_model.modules():
221
+ if module.__class__.__name__ in ["Embedding"]:
222
+ # print("set", module.__class__.__name__, "to", target_dtype)
223
+ module.to(target_dtype)
224
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
225
+ # print("set", module.__class__.__name__, "hooks")
226
+ module.forward = forward_hook(module)
227
+
228
+ prepare_fp8(text_encoder.model, org_dtype)
229
+
230
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
231
+ text_encoder_2 = TextEncoder(
232
+ text_encoder_type=text_encoder_2_type,
233
+ max_length=text_len_2,
234
+ text_encoder_dtype=text_encoder_dtype,
235
+ text_encoder_path=args.text_encoder2,
236
+ tokenizer_type=text_encoder_2_type,
237
+ reproduce=reproduce,
238
+ )
239
+ text_encoder_2.eval()
240
+
241
+ # encode prompt
242
+ logger.info(f"Encoding prompt with text encoder 1")
243
+ text_encoder.to(device=device)
244
+ if fp8_llm:
245
+ with accelerator.autocast():
246
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
247
+ else:
248
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
249
+ text_encoder = None
250
+ clean_memory_on_device(device)
251
+
252
+ logger.info(f"Encoding prompt with text encoder 2")
253
+ text_encoder_2.to(device=device)
254
+ prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
255
+
256
+ prompt_embeds = prompt_embeds.to("cpu")
257
+ prompt_mask = prompt_mask.to("cpu")
258
+ prompt_embeds_2 = prompt_embeds_2.to("cpu")
259
+ prompt_mask_2 = prompt_mask_2.to("cpu")
260
+
261
+ text_encoder_2 = None
262
+ clean_memory_on_device(device)
263
+
264
+ return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
265
+
266
+
267
+ # endregion
268
+
269
+
270
+ def decode_latents(args, latents, device):
271
+ vae_dtype = torch.float16
272
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
273
+ vae.eval()
274
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
275
+
276
+ # set chunk_size to CausalConv3d recursively
277
+ chunk_size = args.vae_chunk_size
278
+ if chunk_size is not None:
279
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
280
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
281
+
282
+ expand_temporal_dim = False
283
+ if len(latents.shape) == 4:
284
+ latents = latents.unsqueeze(2)
285
+ expand_temporal_dim = True
286
+ elif len(latents.shape) == 5:
287
+ pass
288
+ else:
289
+ raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
290
+
291
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
292
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
293
+ else:
294
+ latents = latents / vae.config.scaling_factor
295
+
296
+ latents = latents.to(device=device, dtype=vae.dtype)
297
+ if args.vae_spatial_tile_sample_min_size is not None:
298
+ vae.enable_spatial_tiling(True)
299
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
300
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
301
+ # elif args.vae_tiling:
302
+ else:
303
+ vae.enable_spatial_tiling(True)
304
+ with torch.no_grad():
305
+ image = vae.decode(latents, return_dict=False)[0]
306
+
307
+ if expand_temporal_dim or image.shape[2] == 1:
308
+ image = image.squeeze(2)
309
+
310
+ image = (image / 2 + 0.5).clamp(0, 1)
311
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
312
+ image = image.cpu().float()
313
+
314
+ return image
315
+
316
+
317
+ def parse_args():
318
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
319
+
320
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
321
+ parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
322
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
323
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
324
+
325
+ # LoRA
326
+ parser.add_argument("--lora_weight", type=str, required=False, default=None, help="LoRA weight path")
327
+ parser.add_argument("--lora_multiplier", type=float, default=1.0, help="LoRA multiplier")
328
+
329
+ parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
330
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
331
+ parser.add_argument("--video_length", type=int, default=129, help="video length")
332
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
333
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
334
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
335
+ parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier-free guidance scale.")
336
+
337
+ # Flow Matching
338
+ parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
339
+
340
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
341
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
342
+ parser.add_argument(
343
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
344
+ )
345
+ parser.add_argument(
346
+ "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "sdpa"], help="attention mode"
347
+ )
348
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
349
+ parser.add_argument(
350
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
351
+ )
352
+ parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
353
+ parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
354
+ parser.add_argument("--output_type", type=str, default="video", help="output type: video, latent or both")
355
+ parser.add_argument("--latent_path", type=str, default=None, help="path to latent for decode. no inference")
356
+
357
+ args = parser.parse_args()
358
+
359
+ assert args.latent_path is None or args.output_type == "video", "latent-path is only supported with output-type=video"
360
+
361
+ # update dit_weight based on model_base if not exists
362
+
363
+ return args
364
+
365
+
366
+ def check_inputs(args):
367
+ height = args.video_size[0]
368
+ width = args.video_size[1]
369
+ video_length = args.video_length
370
+
371
+ if height % 8 != 0 or width % 8 != 0:
372
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
373
+ return height, width, video_length
374
+
375
+
376
+ def main():
377
+ args = parse_args()
378
+
379
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
380
+ device = torch.device(device)
381
+ dit_dtype = torch.bfloat16
382
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
383
+ logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
384
+
385
+ if args.latent_path is not None:
386
+ latents = torch.load(args.latent_path, map_location="cpu")
387
+ logger.info(f"Loaded latent from {args.latent_path}. Shape: {latents.shape}")
388
+ latents = latents.unsqueeze(0)
389
+ seeds = [0] # dummy seed
390
+ else:
391
+ # prepare accelerator
392
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
393
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
394
+
395
+ # load prompt
396
+ prompt = args.prompt # TODO load prompts from file
397
+ assert prompt is not None, "prompt is required"
398
+
399
+ # check inputs: may be height, width, video_length etc will be changed for each generation in future
400
+ height, width, video_length = check_inputs(args)
401
+
402
+ # encode prompt with LLM and Text Encoder
403
+ logger.info(f"Encoding prompt: {prompt}")
404
+ prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
405
+ prompt, args, device, args.fp8_llm, accelerator
406
+ )
407
+
408
+ # load DiT model
409
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
410
+ loading_device = "cpu" if blocks_to_swap > 0 else device
411
+
412
+ logger.info(f"Loading DiT model from {args.dit}")
413
+ if args.attn_mode == "sdpa":
414
+ args.attn_mode = "torch"
415
+ transformer = load_transformer(args.dit, args.attn_mode, loading_device, dit_dtype)
416
+ transformer.eval()
417
+
418
+ # load LoRA weights
419
+ if args.lora_weight is not None:
420
+ logger.info(f"Loading LoRA weights from {args.lora_weight}")
421
+ weights_sd = load_file(args.lora_weight)
422
+ network = lora.create_network_from_weights_hunyuan_video(
423
+ args.lora_multiplier, weights_sd, unet=transformer, for_inference=True
424
+ )
425
+ logger.info("Merging LoRA weights to DiT model")
426
+ network.merge_to(None, transformer, weights_sd, device=device)
427
+ logger.info("LoRA weights loaded")
428
+
429
+ if blocks_to_swap > 0:
430
+ logger.info(f"Casting model to {dit_weight_dtype}")
431
+ transformer.to(dtype=dit_weight_dtype)
432
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
433
+ transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
434
+ transformer.move_to_device_except_swap_blocks(device)
435
+ transformer.prepare_block_swap_before_forward()
436
+ else:
437
+ logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}")
438
+ transformer.to(device=device, dtype=dit_weight_dtype)
439
+ if args.img_in_txt_in_offloading:
440
+ logger.info("Enable offloading img_in and txt_in to CPU")
441
+ transformer.enable_img_in_txt_in_offloading()
442
+
443
+ # load scheduler
444
+ logger.info(f"Loading scheduler")
445
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
446
+
447
+ # Prepare timesteps
448
+ num_inference_steps = args.infer_steps
449
+ scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
450
+ timesteps = scheduler.timesteps
451
+
452
+ # Prepare generator
453
+ num_videos_per_prompt = 1 # args.num_videos
454
+ seed = args.seed
455
+ if seed is None:
456
+ seeds = [random.randint(0, 1_000_000) for _ in range(num_videos_per_prompt)]
457
+ elif isinstance(seed, int):
458
+ seeds = [seed + i for i in range(num_videos_per_prompt)]
459
+ else:
460
+ raise ValueError(f"Seed must be an integer or None, got {seed}.")
461
+ generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
462
+
463
+ # Prepare latents
464
+ num_channels_latents = 16 # transformer.config.in_channels
465
+ vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
466
+
467
+ vae_ver = vae.VAE_VER
468
+ if "884" in vae_ver:
469
+ latent_video_length = (video_length - 1) // 4 + 1
470
+ elif "888" in vae_ver:
471
+ latent_video_length = (video_length - 1) // 8 + 1
472
+ else:
473
+ latent_video_length = video_length
474
+
475
+ shape = (
476
+ num_videos_per_prompt,
477
+ num_channels_latents,
478
+ latent_video_length,
479
+ height // vae_scale_factor,
480
+ width // vae_scale_factor,
481
+ )
482
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
483
+ # FlowMatchDiscreteScheduler does not have init_noise_sigma
484
+
485
+ # Denoising loop
486
+ embedded_guidance_scale = args.embedded_cfg_scale
487
+ if embedded_guidance_scale is not None:
488
+ guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
489
+ guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
490
+ else:
491
+ guidance_expand = None
492
+ freqs_cos, freqs_sin = get_rotary_pos_embed(vae.VAE_VER, transformer, video_length, height, width)
493
+ # n_tokens = freqs_cos.shape[0]
494
+
495
+ # move and cast all inputs to the correct device and dtype
496
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
497
+ prompt_mask = prompt_mask.to(device=device)
498
+ prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
499
+ prompt_mask_2 = prompt_mask_2.to(device=device)
500
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
501
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
502
+
503
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
504
+ # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
505
+ with tqdm(total=num_inference_steps) as progress_bar:
506
+ for i, t in enumerate(timesteps):
507
+ latents = scheduler.scale_model_input(latents, t)
508
+
509
+ # predict the noise residual
510
+ with torch.no_grad(), accelerator.autocast():
511
+ noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
512
+ latents, # [1, 16, 33, 24, 42]
513
+ t.repeat(latents.shape[0]).to(device=device, dtype=dit_dtype), # [1]
514
+ text_states=prompt_embeds, # [1, 256, 4096]
515
+ text_mask=prompt_mask, # [1, 256]
516
+ text_states_2=prompt_embeds_2, # [1, 768]
517
+ freqs_cos=freqs_cos, # [seqlen, head_dim]
518
+ freqs_sin=freqs_sin, # [seqlen, head_dim]
519
+ guidance=guidance_expand,
520
+ return_dict=True,
521
+ )["x"]
522
+
523
+ # compute the previous noisy sample x_t -> x_t-1
524
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
525
+
526
+ # update progress bar
527
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
528
+ if progress_bar is not None:
529
+ progress_bar.update()
530
+
531
+ # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
532
+ # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
533
+
534
+ latents = latents.detach().cpu()
535
+ transformer = None
536
+ clean_memory_on_device(device)
537
+
538
+ # Save samples
539
+ output_type = args.output_type
540
+ save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
541
+ os.makedirs(save_path, exist_ok=True)
542
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
543
+
544
+ if output_type == "latent" or output_type == "both":
545
+ # save latent
546
+ for i, latent in enumerate(latents):
547
+ latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.pt"
548
+ torch.save(latent, latent_path)
549
+ logger.info(f"Latent saved to: {latent_path}")
550
+ if output_type == "video" or output_type == "both":
551
+ # save video
552
+ videos = decode_latents(args, latents, device)
553
+ for i, sample in enumerate(videos):
554
+ sample = sample.unsqueeze(0)
555
+ video_path = f"{save_path}/{time_flag}_{seeds[i]}.mp4"  # use a separate name so save_path is not overwritten inside the loop
556
+ save_videos_grid(sample, video_path, fps=24)
557
+ logger.info(f"Sample saved to: {video_path}")
558
+
559
+ logger.info("Done!")
560
+
561
+
562
+ if __name__ == "__main__":
563
+ main()
hv_train_network.py ADDED
@@ -0,0 +1,2129 @@
1
+ import ast
2
+ import asyncio
3
+ from datetime import datetime, timedelta
4
+ import gc
5
+ import importlib
6
+ import argparse
7
+ import math
8
+ import os
9
+ import pathlib
10
+ import re
11
+ import sys
12
+ import random
13
+ import time
14
+ import json
15
+ from multiprocessing import Value
16
+ from typing import Any, Dict, List, Optional
17
+ import accelerate
18
+ import numpy as np
19
+ from packaging.version import Version
20
+
21
+ import huggingface_hub
22
+ import toml
23
+
24
+ import torch
25
+ from tqdm import tqdm
26
+ from accelerate.utils import set_seed
27
+ from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
28
+ from safetensors.torch import load_file
29
+ import transformers
30
+ from diffusers.optimization import (
31
+ SchedulerType as DiffusersSchedulerType,
32
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
33
+ )
34
+ from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+
36
+ from dataset import config_utils
37
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
38
+ import hunyuan_model.text_encoder as text_encoder_module
39
+ from hunyuan_model.vae import load_vae
40
+ import hunyuan_model.vae as vae_module
41
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
42
+ import networks.lora as lora_module
43
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
44
+
45
+ import logging
46
+
47
+ from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
48
+
49
+ logger = logging.getLogger(__name__)
50
+ logging.basicConfig(level=logging.INFO)
51
+
52
+
53
+ BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
54
+
55
+ SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version"
56
+ SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module"
57
+ SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim"
58
+ SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha"
59
+ SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args"
60
+
61
+ SS_METADATA_MINIMUM_KEYS = [
62
+ SS_METADATA_KEY_BASE_MODEL_VERSION,
63
+ SS_METADATA_KEY_NETWORK_MODULE,
64
+ SS_METADATA_KEY_NETWORK_DIM,
65
+ SS_METADATA_KEY_NETWORK_ALPHA,
66
+ SS_METADATA_KEY_NETWORK_ARGS,
67
+ ]
68
+
69
+
70
+ def clean_memory_on_device(device: torch.device):
71
+ r"""
72
+ Clean memory on the specified device, will be called from training scripts.
73
+ """
74
+ gc.collect()
75
+
76
+ # device may be "cuda" or "cuda:0", so we need to check the type of the device
77
+ if device.type == "cuda":
78
+ torch.cuda.empty_cache()
79
+ if device.type == "xpu":
80
+ torch.xpu.empty_cache()
81
+ if device.type == "mps":
82
+ torch.mps.empty_cache()
83
+
84
+
85
+ # for collate_fn: epoch and step is multiprocessing.Value
86
+ class collator_class:
87
+ def __init__(self, epoch, step, dataset):
88
+ self.current_epoch = epoch
89
+ self.current_step = step
90
+ self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
91
+
92
+ def __call__(self, examples):
93
+ worker_info = torch.utils.data.get_worker_info()
94
+ # worker_info is None in the main process
95
+ if worker_info is not None:
96
+ dataset = worker_info.dataset
97
+ else:
98
+ dataset = self.dataset
99
+
100
+ # set epoch and step
101
+ dataset.set_current_epoch(self.current_epoch.value)
102
+ dataset.set_current_step(self.current_step.value)
103
+ return examples[0]
104
+
105
+
106
+ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
107
+ """
108
+ DeepSpeed is not supported in this script currently.
109
+ """
110
+ if args.logging_dir is None:
111
+ logging_dir = None
112
+ else:
113
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
114
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
115
+
116
+ if args.log_with is None:
117
+ if logging_dir is not None:
118
+ log_with = "tensorboard"
119
+ else:
120
+ log_with = None
121
+ else:
122
+ log_with = args.log_with
123
+ if log_with in ["tensorboard", "all"]:
124
+ if logging_dir is None:
125
+ raise ValueError(
126
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
127
+ )
128
+ if log_with in ["wandb", "all"]:
129
+ try:
130
+ import wandb
131
+ except ImportError:
132
+ raise ImportError("No wandb / wandb がインストールされていないようです")
133
+ if logging_dir is not None:
134
+ os.makedirs(logging_dir, exist_ok=True)
135
+ os.environ["WANDB_DIR"] = logging_dir
136
+ if args.wandb_api_key is not None:
137
+ wandb.login(key=args.wandb_api_key)
138
+
139
+ kwargs_handlers = [
140
+ (
141
+ InitProcessGroupKwargs(
142
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
143
+ init_method=(
144
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
145
+ ),
146
+ timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
147
+ )
148
+ if torch.cuda.device_count() > 1
149
+ else None
150
+ ),
151
+ (
152
+ DistributedDataParallelKwargs(
153
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
154
+ )
155
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
156
+ else None
157
+ ),
158
+ ]
159
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
160
+
161
+ accelerator = Accelerator(
162
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
163
+ mixed_precision=args.mixed_precision,
164
+ log_with=log_with,
165
+ project_dir=logging_dir,
166
+ kwargs_handlers=kwargs_handlers,
167
+ )
168
+ print("accelerator device:", accelerator.device)
169
+ return accelerator
170
+
171
+
172
+ def line_to_prompt_dict(line: str) -> dict:
173
+ # subset of gen_img_diffusers
174
+ prompt_args = line.split(" --")
175
+ prompt_dict = {}
176
+ prompt_dict["prompt"] = prompt_args[0]
177
+
178
+ for parg in prompt_args:
179
+ try:
180
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
181
+ if m:
182
+ prompt_dict["width"] = int(m.group(1))
183
+ continue
184
+
185
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
186
+ if m:
187
+ prompt_dict["height"] = int(m.group(1))
188
+ continue
189
+
190
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
191
+ if m:
192
+ prompt_dict["frame_count"] = int(m.group(1))
193
+ continue
194
+
195
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
196
+ if m:
197
+ prompt_dict["seed"] = int(m.group(1))
198
+ continue
199
+
200
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
201
+ if m: # steps
202
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
203
+ continue
204
+
205
+ # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
206
+ # if m: # scale
207
+ # prompt_dict["scale"] = float(m.group(1))
208
+ # continue
209
+ # m = re.match(r"n (.+)", parg, re.IGNORECASE)
210
+ # if m: # negative prompt
211
+ # prompt_dict["negative_prompt"] = m.group(1)
212
+ # continue
213
+
214
+ except ValueError as ex:
215
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
216
+ logger.error(ex)
217
+
218
+ return prompt_dict
219
+
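# --- illustrative sketch: what line_to_prompt_dict returns for a typical sample-prompt line ---
# (keys follow the regexes above: w=width, h=height, f=frame_count, d=seed, s=sample_steps)
example = line_to_prompt_dict("a cat walking on the grass --w 512 --h 320 --f 25 --d 42 --s 30")
# -> {"prompt": "a cat walking on the grass", "width": 512, "height": 320,
#     "frame_count": 25, "seed": 42, "sample_steps": 30}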
220
+
221
+ def load_prompts(prompt_file: str) -> list[Dict]:
222
+ # read prompts
223
+ if prompt_file.endswith(".txt"):
224
+ with open(prompt_file, "r", encoding="utf-8") as f:
225
+ lines = f.readlines()
226
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
227
+ elif prompt_file.endswith(".toml"):
228
+ with open(prompt_file, "r", encoding="utf-8") as f:
229
+ data = toml.load(f)
230
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
231
+ elif prompt_file.endswith(".json"):
232
+ with open(prompt_file, "r", encoding="utf-8") as f:
233
+ prompts = json.load(f)
234
+
235
+ # preprocess prompts
236
+ for i in range(len(prompts)):
237
+ prompt_dict = prompts[i]
238
+ if isinstance(prompt_dict, str):
239
+ prompt_dict = line_to_prompt_dict(prompt_dict)
240
+ prompts[i] = prompt_dict
241
+ assert isinstance(prompt_dict, dict)
242
+
243
+ # Add an "enum" index based on the prompt's position (used later to name image files) and clean up extra data in the original prompt dict.
244
+ prompt_dict["enum"] = i
245
+ prompt_dict.pop("subset", None)
246
+
247
+ return prompts
248
+
249
+
250
+ def compute_density_for_timestep_sampling(
251
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
252
+ ):
253
+ """Compute the density for sampling the timesteps when doing SD3 training.
254
+
255
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
256
+
257
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
258
+ """
259
+ if weighting_scheme == "logit_normal":
260
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
261
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
262
+ u = torch.nn.functional.sigmoid(u)
263
+ elif weighting_scheme == "mode":
264
+ u = torch.rand(size=(batch_size,), device="cpu")
265
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
266
+ else:
267
+ u = torch.rand(size=(batch_size,), device="cpu")
268
+ return u
269
+
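# --- illustrative sketch: logit-normal timestep density sampling ---
# u lies in (0, 1) and concentrates around 0.5; how it is then mapped onto discrete scheduler
# timesteps is assumed, not shown in this function.
u = compute_density_for_timestep_sampling("logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0)
# u.shape == (4,), e.g. tensor([0.62, 0.31, 0.55, 0.48])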
270
+
271
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
272
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
273
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
274
+ timesteps = timesteps.to(device)
275
+
276
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
277
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
278
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
279
+ # round to nearest timestep
280
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
281
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
282
+ else:
283
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
284
+
285
+ sigma = sigmas[step_indices].flatten()
286
+ while len(sigma.shape) < n_dim:
287
+ sigma = sigma.unsqueeze(-1)
288
+ return sigma
289
+
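# --- illustrative sketch: get_sigmas returns one sigma per sampled timestep, unsqueezed so it
# broadcasts against a 5-D video latent when n_dim=5. FlowMatchDiscreteScheduler is assumed to
# expose .sigmas and .timesteps after set_timesteps, exactly as this helper uses them. ---
sched = FlowMatchDiscreteScheduler(shift=7.0, reverse=True, solver="euler")
sched.set_timesteps(50, device=torch.device("cpu"))
sig = get_sigmas(sched, sched.timesteps[:2], torch.device("cpu"), n_dim=5)  # shape (2, 1, 1, 1, 1)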
290
+
291
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
292
+ """Computes loss weighting scheme for SD3 training.
293
+
294
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
295
+
296
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
297
+ """
298
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
299
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
300
+ if weighting_scheme == "sigma_sqrt":
301
+ weighting = (sigmas**-2.0).float()
302
+ else:
303
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
304
+ weighting = 2 / (math.pi * bot)
305
+ else:
306
+ weighting = None # torch.ones_like(sigmas)
307
+ return weighting
308
+
309
+
310
+ class NetworkTrainer:
311
+ def __init__(self):
312
+ pass
313
+
314
+ # TODO: unify this with the other scripts
315
+ def generate_step_logs(
316
+ self,
317
+ args: argparse.Namespace,
318
+ current_loss,
319
+ avr_loss,
320
+ lr_scheduler,
321
+ lr_descriptions,
322
+ optimizer=None,
323
+ keys_scaled=None,
324
+ mean_norm=None,
325
+ maximum_norm=None,
326
+ ):
327
+ network_train_unet_only = True
328
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
329
+
330
+ if keys_scaled is not None:
331
+ logs["max_norm/keys_scaled"] = keys_scaled
332
+ logs["max_norm/average_key_norm"] = mean_norm
333
+ logs["max_norm/max_key_norm"] = maximum_norm
334
+
335
+ lrs = lr_scheduler.get_last_lr()
336
+ for i, lr in enumerate(lrs):
337
+ if lr_descriptions is not None:
338
+ lr_desc = lr_descriptions[i]
339
+ else:
340
+ idx = i - (0 if network_train_unet_only else -1)
341
+ if idx == -1:
342
+ lr_desc = "textencoder"
343
+ else:
344
+ if len(lrs) > 2:
345
+ lr_desc = f"group{idx}"
346
+ else:
347
+ lr_desc = "unet"
348
+
349
+ logs[f"lr/{lr_desc}"] = lr
350
+
351
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
352
+ # tracking d*lr value
353
+ logs[f"lr/d*lr/{lr_desc}"] = (
354
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
355
+ )
356
+ if (
357
+ args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
358
+ ): # tracking d*lr value of unet.
359
+ logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
360
+ else:
361
+ idx = 0
362
+ if not network_train_unet_only:
363
+ logs["lr/textencoder"] = float(lrs[0])
364
+ idx = 1
365
+
366
+ for i in range(idx, len(lrs)):
367
+ logs[f"lr/group{i}"] = float(lrs[i])
368
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
369
+ logs[f"lr/d*lr/group{i}"] = (
370
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
371
+ )
372
+ if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
373
+ logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
374
+
375
+ return logs
376
+
377
+ def process_sample_prompts(
378
+ self,
379
+ args: argparse.Namespace,
380
+ accelerator: Accelerator,
381
+ sample_prompts: str,
382
+ text_encoder1: str,
383
+ text_encoder2: str,
384
+ fp8_llm: bool,
385
+ ):
386
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
387
+ prompts = load_prompts(sample_prompts)
388
+
389
+ def encode_for_text_encoder(text_encoder):
390
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
391
+ with accelerator.autocast(), torch.no_grad():
392
+ for prompt_dict in prompts:
393
+ for p in [prompt_dict.get("prompt", "")]:
394
+ if p not in sample_prompts_te_outputs:
395
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
396
+
397
+ data_type = "video"
398
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
399
+
400
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
401
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
402
+
403
+ return sample_prompts_te_outputs
404
+
405
+ # Load Text Encoder 1 and encode
406
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
407
+ logger.info(f"loading text encoder 1: {text_encoder1}")
408
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
409
+
410
+ logger.info("encoding with Text Encoder 1")
411
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
412
+ del text_encoder_1
413
+
414
+ # Load Text Encoder 2 and encode
415
+ logger.info(f"loading text encoder 2: {text_encoder2}")
416
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
417
+
418
+ logger.info("encoding with Text Encoder 2")
419
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2)  # encode_for_text_encoder takes only the text encoder
420
+ del text_encoder_2
421
+
422
+ # prepare sample parameters
423
+ sample_parameters = []
424
+ for prompt_dict in prompts:
425
+ prompt_dict_copy = prompt_dict.copy()
426
+ p = prompt_dict.get("prompt", "")
427
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
428
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
429
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
430
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
431
+ sample_parameters.append(prompt_dict_copy)
432
+
433
+ clean_memory_on_device(accelerator.device)
434
+
435
+ return sample_parameters
436
+
437
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
438
+ # adamw, adamw8bit, adafactor
439
+
440
+ optimizer_type = args.optimizer_type.lower()  # comparisons below expect a lowercased name
441
+
442
+ # split optimizer_type and optimizer_args
443
+ optimizer_kwargs = {}
444
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
445
+ for arg in args.optimizer_args:
446
+ key, value = arg.split("=")
447
+ value = ast.literal_eval(value)
448
+ optimizer_kwargs[key] = value
449
+
450
+ lr = args.learning_rate
451
+ optimizer = None
452
+ optimizer_class = None
453
+
454
+ if optimizer_type.endswith("8bit".lower()):
455
+ try:
456
+ import bitsandbytes as bnb
457
+ except ImportError:
458
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
459
+
460
+ if optimizer_type == "AdamW8bit".lower():
461
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
462
+ optimizer_class = bnb.optim.AdamW8bit
463
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
464
+
465
+ elif optimizer_type == "Adafactor".lower():
466
+ # Adafactor: check relative_step and warmup_init
467
+ if "relative_step" not in optimizer_kwargs:
468
+ optimizer_kwargs["relative_step"] = True # default
469
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
470
+ logger.info(
471
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
472
+ )
473
+ optimizer_kwargs["relative_step"] = True
474
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
475
+
476
+ if optimizer_kwargs["relative_step"]:
477
+ logger.info(f"relative_step is true / relative_stepがtrueです")
478
+ if lr != 0.0:
479
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
480
+ args.learning_rate = None
481
+
482
+ if args.lr_scheduler != "adafactor":
483
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
484
+ args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, but it works
485
+
486
+ lr = None
487
+ else:
488
+ if args.max_grad_norm != 0.0:
489
+ logger.warning(
490
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
491
+ )
492
+ if args.lr_scheduler != "constant_with_warmup":
493
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
494
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
495
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
496
+
497
+ optimizer_class = transformers.optimization.Adafactor
498
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
499
+
500
+ elif optimizer_type == "AdamW".lower():
501
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
502
+ optimizer_class = torch.optim.AdamW
503
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
504
+
505
+ if optimizer is None:
506
+ # use an arbitrary optimizer specified by name
507
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
508
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
509
+
510
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
511
+ optimizer_module = torch.optim
512
+ else: # from other library
513
+ values = case_sensitive_optimizer_type.split(".")
514
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
515
+ case_sensitive_optimizer_type = values[-1]
516
+
517
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
518
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
519
+
520
+ # for logging
521
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
522
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
523
+
524
+ # get train and eval functions
525
+ if hasattr(optimizer, "train") and callable(optimizer.train):
526
+ train_fn = optimizer.train
527
+ eval_fn = optimizer.eval
528
+ else:
529
+ train_fn = lambda: None
530
+ eval_fn = lambda: None
531
+
532
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
533
+
534
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
535
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
536
+
537
+ def get_dummy_scheduler(self, optimizer: torch.optim.Optimizer) -> Any:
538
+ # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
539
+ # this scheduler is used for logging only.
540
+ # this is not wrapped by accelerator because this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
541
+ class DummyScheduler:
542
+ def __init__(self, optimizer: torch.optim.Optimizer):
543
+ self.optimizer = optimizer
544
+
545
+ def step(self):
546
+ pass
547
+
548
+ def get_last_lr(self):
549
+ return [group["lr"] for group in self.optimizer.param_groups]
550
+
551
+ return DummyScheduler(optimizer)
552
+
553
+ def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
554
+ """
555
+ Unified API to get any scheduler from its name.
556
+ """
557
+ # if schedulefree optimizer, return dummy scheduler
558
+ if self.is_schedulefree_optimizer(optimizer, args):
559
+ return self.get_dummy_scheduler(optimizer)
560
+
561
+ name = args.lr_scheduler
562
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
563
+ num_warmup_steps: Optional[int] = (
564
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
565
+ )
566
+ num_decay_steps: Optional[int] = (
567
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
568
+ )
569
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
570
+ num_cycles = args.lr_scheduler_num_cycles
571
+ power = args.lr_scheduler_power
572
+ timescale = args.lr_scheduler_timescale
573
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
574
+
575
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
576
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
577
+ for arg in args.lr_scheduler_args:
578
+ key, value = arg.split("=")
579
+ value = ast.literal_eval(value)
580
+ lr_scheduler_kwargs[key] = value
581
+
582
+ def wrap_check_needless_num_warmup_steps(return_vals):
583
+ if num_warmup_steps is not None and num_warmup_steps != 0:
584
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
585
+ return return_vals
586
+
587
+ # using any lr_scheduler from other library
588
+ if args.lr_scheduler_type:
589
+ lr_scheduler_type = args.lr_scheduler_type
590
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
591
+ if "." not in lr_scheduler_type: # default to use torch.optim
592
+ lr_scheduler_module = torch.optim.lr_scheduler
593
+ else:
594
+ values = lr_scheduler_type.split(".")
595
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
596
+ lr_scheduler_type = values[-1]
597
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
598
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
599
+ return lr_scheduler
600
+
601
+ if name.startswith("adafactor"):
602
+ assert (
603
+ type(optimizer) == transformers.optimization.Adafactor
604
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
605
+ initial_lr = float(name.split(":")[1])
606
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
607
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
608
+
609
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
610
+ name = DiffusersSchedulerType(name)
611
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
612
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
613
+
614
+ name = SchedulerType(name)
615
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
616
+
617
+ if name == SchedulerType.CONSTANT:
618
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
619
+
620
+ # All other schedulers require `num_warmup_steps`
621
+ if num_warmup_steps is None:
622
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
623
+
624
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
625
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
626
+
627
+ if name == SchedulerType.INVERSE_SQRT:
628
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
629
+
630
+ # All other schedulers require `num_training_steps`
631
+ if num_training_steps is None:
632
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
633
+
634
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
635
+ return schedule_func(
636
+ optimizer,
637
+ num_warmup_steps=num_warmup_steps,
638
+ num_training_steps=num_training_steps,
639
+ num_cycles=num_cycles,
640
+ **lr_scheduler_kwargs,
641
+ )
642
+
643
+ if name == SchedulerType.POLYNOMIAL:
644
+ return schedule_func(
645
+ optimizer,
646
+ num_warmup_steps=num_warmup_steps,
647
+ num_training_steps=num_training_steps,
648
+ power=power,
649
+ **lr_scheduler_kwargs,
650
+ )
651
+
652
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
653
+ return schedule_func(
654
+ optimizer,
655
+ num_warmup_steps=num_warmup_steps,
656
+ num_training_steps=num_training_steps,
657
+ num_cycles=num_cycles / 2,
658
+ min_lr_rate=min_lr_ratio,
659
+ **lr_scheduler_kwargs,
660
+ )
661
+
662
+ # these schedulers do not require `num_decay_steps`
663
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
664
+ return schedule_func(
665
+ optimizer,
666
+ num_warmup_steps=num_warmup_steps,
667
+ num_training_steps=num_training_steps,
668
+ **lr_scheduler_kwargs,
669
+ )
670
+
671
+ # All other schedulers require `num_decay_steps`
672
+ if num_decay_steps is None:
673
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
674
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
675
+ return schedule_func(
676
+ optimizer,
677
+ num_warmup_steps=num_warmup_steps,
678
+ num_stable_steps=num_stable_steps,
679
+ num_decay_steps=num_decay_steps,
680
+ num_cycles=num_cycles / 2,
681
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
682
+ **lr_scheduler_kwargs,
683
+ )
684
+
685
+ return schedule_func(
686
+ optimizer,
687
+ num_warmup_steps=num_warmup_steps,
688
+ num_training_steps=num_training_steps,
689
+ num_decay_steps=num_decay_steps,
690
+ **lr_scheduler_kwargs,
691
+ )
692
+
693
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
694
+ if not args.resume:
695
+ return False
696
+
697
+ if not args.resume_from_huggingface:
698
+ logger.info(f"resume training from local state: {args.resume}")
699
+ accelerator.load_state(args.resume)
700
+ return True
701
+
702
+ logger.info(f"resume training from huggingface state: {args.resume}")
703
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
704
+ path_in_repo = "/".join(args.resume.split("/")[2:])
705
+ revision = None
706
+ repo_type = None
707
+ if ":" in path_in_repo:
708
+ divided = path_in_repo.split(":")
709
+ if len(divided) == 2:
710
+ path_in_repo, revision = divided
711
+ repo_type = "model"
712
+ else:
713
+ path_in_repo, revision, repo_type = divided
714
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
715
+
716
+ list_files = huggingface_utils.list_dir(
717
+ repo_id=repo_id,
718
+ subfolder=path_in_repo,
719
+ revision=revision,
720
+ token=args.huggingface_token,
721
+ repo_type=repo_type,
722
+ )
723
+
724
+ async def download(filename) -> str:
725
+ def task():
726
+ return huggingface_hub.hf_hub_download(
727
+ repo_id=repo_id,
728
+ filename=filename,
729
+ revision=revision,
730
+ repo_type=repo_type,
731
+ token=args.huggingface_token,
732
+ )
733
+
734
+ return await asyncio.get_event_loop().run_in_executor(None, task)
735
+
736
+ loop = asyncio.get_event_loop()
737
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
738
+ if len(results) == 0:
739
+ raise ValueError(
740
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
741
+ )
742
+ dirname = os.path.dirname(results[0])
743
+ accelerator.load_state(dirname)
744
+
745
+ return True
746
+
747
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
748
+ pass
749
+
750
+ def get_noisy_model_input_and_timesteps(
751
+ self,
752
+ args: argparse.Namespace,
753
+ noise: torch.Tensor,
754
+ latents: torch.Tensor,
755
+ noise_scheduler: FlowMatchDiscreteScheduler,
756
+ device: torch.device,
757
+ dtype: torch.dtype,
758
+ ):
759
+ batch_size = noise.shape[0]
760
+
761
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
762
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
763
+ # Simple random t-based noise sampling
764
+ if args.timestep_sampling == "sigmoid":
765
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
766
+ else:
767
+ t = torch.rand((batch_size,), device=device)
768
+
769
+ elif args.timestep_sampling == "shift":
770
+ shift = args.discrete_flow_shift
771
+ logits_norm = torch.randn(batch_size, device=device)
772
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
773
+ t = logits_norm.sigmoid()
774
+ t = (t * shift) / (1 + (shift - 1) * t)
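+ # discrete flow shift: t -> shift * t / (1 + (shift - 1) * t); shift > 1 biases sampling toward t=1 (noisier timesteps)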
775
+
776
+ t_min = args.min_timestep if args.min_timestep is not None else 0
777
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
778
+ t_min /= 1000.0
779
+ t_max /= 1000.0
780
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
781
+
782
+ timesteps = t * 1000.0
783
+ t = t.view(-1, 1, 1, 1, 1)
784
+ noisy_model_input = (1 - t) * latents + t * noise
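+ # linear (rectified-flow style) interpolation: t=0 gives clean latents, t=1 gives pure noise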
785
+
786
+ timesteps += 1 # 1 to 1000
787
+ else:
788
+ # Sample a random timestep for each image
789
+ # for weighting schemes where we sample timesteps non-uniformly
790
+ u = compute_density_for_timestep_sampling(
791
+ weighting_scheme=args.weighting_scheme,
792
+ batch_size=batch_size,
793
+ logit_mean=args.logit_mean,
794
+ logit_std=args.logit_std,
795
+ mode_scale=args.mode_scale,
796
+ )
797
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
798
+ t_min = args.min_timestep if args.min_timestep is not None else 0
799
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
800
+ indices = (u * (t_max - t_min) + t_min).long()
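+ # map the sampled density u in [0, 1) onto integer timestep indices in [t_min, t_max)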
801
+
802
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
803
+
804
+ # Add noise according to flow matching.
805
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
806
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
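+ # same linear mixing as the branch above, but using the scheduler's sigma for the sampled timestep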
807
+
808
+ return noisy_model_input, timesteps
809
+
810
+ def show_timesteps(self, args: argparse.Namespace):
811
+ N_TRY = 100000
812
+ BATCH_SIZE = 1000
813
+ CONSOLE_WIDTH = 64
814
+ N_TIMESTEPS_PER_LINE = 25
815
+
816
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
817
+ # print(f"Noise scheduler timesteps: {noise_scheduler.timesteps}")
818
+
819
+ latents = torch.zeros(BATCH_SIZE, 1, 1, 1, 1, dtype=torch.float16)
820
+ noise = torch.ones_like(latents)
821
+
822
+ # sample timesteps
823
+ sampled_timesteps = [0] * noise_scheduler.config.num_train_timesteps
824
+ for i in tqdm(range(N_TRY // BATCH_SIZE)):
825
+ # we use noise=1 and latents=0, so the returned noisy_model_input equals t itself, because noisy_model_input = (1 - t) * latents + t * noise
826
+ actual_timesteps, _ = self.get_noisy_model_input_and_timesteps(
827
+ args, noise, latents, noise_scheduler, "cpu", torch.float16
828
+ )
829
+ actual_timesteps = actual_timesteps[:, 0, 0, 0, 0] * 1000
830
+ for t in actual_timesteps:
831
+ t = int(t.item())
832
+ sampled_timesteps[t] += 1
833
+
834
+ # sample weighting
835
+ sampled_weighting = [0] * noise_scheduler.config.num_train_timesteps
836
+ for i in tqdm(range(len(sampled_weighting))):
837
+ timesteps = torch.tensor([i + 1], device="cpu")
838
+ weighting = compute_loss_weighting_for_sd3(args.weighting_scheme, noise_scheduler, timesteps, "cpu", torch.float16)
839
+ if weighting is None:
840
+ weighting = torch.tensor(1.0, device="cpu")
841
+ elif torch.isinf(weighting).any():
842
+ weighting = torch.tensor(1.0, device="cpu")
843
+ sampled_weighting[i] = weighting.item()
844
+
845
+ # show results
846
+ if args.show_timesteps == "image":
847
+ # show timesteps with matplotlib
848
+ import matplotlib.pyplot as plt
849
+
850
+ plt.figure(figsize=(10, 5))
851
+ plt.subplot(1, 2, 1)
852
+ plt.bar(range(len(sampled_timesteps)), sampled_timesteps, width=1.0)
853
+ plt.title("Sampled timesteps")
854
+ plt.xlabel("Timestep")
855
+ plt.ylabel("Count")
856
+
857
+ plt.subplot(1, 2, 2)
858
+ plt.bar(range(len(sampled_weighting)), sampled_weighting, width=1.0)
859
+ plt.title("Sampled loss weighting")
860
+ plt.xlabel("Timestep")
861
+ plt.ylabel("Weighting")
862
+
863
+ plt.tight_layout()
864
+ plt.show()
865
+
866
+ else:
867
+ sampled_timesteps = np.array(sampled_timesteps)
868
+ sampled_weighting = np.array(sampled_weighting)
869
+
870
+ # average per line
871
+ sampled_timesteps = sampled_timesteps.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
872
+ sampled_weighting = sampled_weighting.reshape(-1, N_TIMESTEPS_PER_LINE).mean(axis=1)
873
+
874
+ max_count = max(sampled_timesteps)
875
+ print(f"Sampled timesteps: max count={max_count}")
876
+ for i, t in enumerate(sampled_timesteps):
877
+ line = f"{(i)*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: "
878
+ line += "#" * int(t / max_count * CONSOLE_WIDTH)
879
+ print(line)
880
+
881
+ max_weighting = max(sampled_weighting)
882
+ print(f"Sampled loss weighting: max weighting={max_weighting}")
883
+ for i, w in enumerate(sampled_weighting):
884
+ line = f"{i*N_TIMESTEPS_PER_LINE:4d}-{(i+1)*N_TIMESTEPS_PER_LINE-1:4d}: {w:8.2f} "
885
+ line += "#" * int(w / max_weighting * CONSOLE_WIDTH)
886
+ print(line)
887
+
888
+ def train(self, args):
889
+ # show timesteps for debugging
890
+ if args.show_timesteps:
891
+ self.show_timesteps(args)
892
+ return
893
+
894
+ session_id = random.randint(0, 2**32)
895
+ training_started_at = time.time()
896
+ # setup_logging(args, reset=True)
897
+
898
+ if args.seed is None:
899
+ args.seed = random.randint(0, 2**32)
900
+ set_seed(args.seed)
901
+
902
+ # Load dataset config
903
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
904
+ logger.info(f"Load dataset config from {args.dataset_config}")
905
+ user_config = config_utils.load_user_config(args.dataset_config)
906
+ blueprint = blueprint_generator.generate(user_config, args)
907
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
908
+
909
+ current_epoch = Value("i", 0)
910
+ current_step = Value("i", 0)
911
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
912
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
913
+
914
+ # prepare accelerator
915
+ logger.info("preparing accelerator")
916
+ accelerator = prepare_accelerator(args)
917
+ is_main_process = accelerator.is_main_process
918
+
919
+ # prepare dtype
920
+ weight_dtype = torch.float32
921
+ if args.mixed_precision == "fp16":
922
+ weight_dtype = torch.float16
923
+ elif args.mixed_precision == "bf16":
924
+ weight_dtype = torch.bfloat16
925
+
926
+ # HunyuanVideo specific
927
+ dit_dtype = torch.bfloat16 if args.dit_dtype is None else model_utils.str_to_dtype(args.dit_dtype)
928
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8_base else dit_dtype
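+ # with --fp8_base the DiT weights are stored in float8_e4m3fn to reduce memory usage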
929
+ logger.info(f"DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
930
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
931
+
932
+ # get embedding for sampling images
933
+ sample_parameters = vae = None
934
+ if args.sample_prompts:
935
+ sample_parameters = self.process_sample_prompts(
936
+ args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
937
+ )
938
+
939
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
940
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
941
+ vae.requires_grad_(False)
942
+ vae.eval()
943
+
944
+ if args.vae_chunk_size is not None:
945
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
946
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
947
+ if args.vae_spatial_tile_sample_min_size is not None:
948
+ vae.enable_spatial_tiling(True)
949
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
950
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
951
+ elif args.vae_tiling:
952
+ vae.enable_spatial_tiling(True)
953
+
954
+ # load DiT model
955
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
956
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
957
+
958
+ logger.info(f"Loading DiT model from {args.dit}")
959
+ if args.sdpa:
960
+ attn_mode = "torch"
961
+ elif args.flash_attn:
962
+ attn_mode = "flash"
963
+ elif args.sage_attn:
964
+ attn_mode = "sageattn"
965
+ else:
966
+ raise ValueError(
967
+ f"either --sdpa or --flash-attn or --sage-attn must be specified / --sdpaか--flash-attnか--sage-attnのいずれかを指定してください"
968
+ )
969
+ transformer = load_transformer(args.dit, attn_mode, loading_device, dit_weight_dtype)
970
+ transformer.eval()
971
+ transformer.requires_grad_(False)
972
+
973
+ if blocks_to_swap > 0:
974
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
975
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
976
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
977
+ if args.img_in_txt_in_offloading:
978
+ logger.info("Enable offloading img_in and txt_in to CPU")
979
+ transformer.enable_img_in_txt_in_offloading()
980
+
981
+ # load network model for differential training
982
+ sys.path.append(os.path.dirname(__file__))
983
+ accelerator.print("import network module:", args.network_module)
984
+ network_module: lora_module = importlib.import_module(args.network_module) # actual module may be different
985
+
986
+ if args.base_weights is not None:
987
+ # if base_weights is specified, merge the weights to DiT model
988
+ for i, weight_path in enumerate(args.base_weights):
989
+ if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
990
+ multiplier = 1.0
991
+ else:
992
+ multiplier = args.base_weights_multiplier[i]
993
+
994
+ accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
995
+
996
+ weights_sd = load_file(weight_path)
997
+ module = network_module.create_network_from_weights_hunyuan_video(multiplier, weights_sd, unet=transformer)
998
+ module.merge_to(None, transformer, weights_sd, weight_dtype, "cpu")
999
+
1000
+ accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
1001
+
1002
+ # prepare network
1003
+ net_kwargs = {}
1004
+ if args.network_args is not None:
1005
+ for net_arg in args.network_args:
1006
+ key, value = net_arg.split("=")
1007
+ net_kwargs[key] = value
1008
+
1009
+ if args.dim_from_weights:
1010
+ logger.info(f"Loading network from weights: {args.dim_from_weights}")
1011
+ weights_sd = load_file(args.dim_from_weights)
1012
+ network, _ = network_module.create_network_from_weights_hunyuan_video(1, weights_sd, unet=transformer)
1013
+ else:
1014
+ network = network_module.create_network_hunyuan_video(
1015
+ 1.0,
1016
+ args.network_dim,
1017
+ args.network_alpha,
1018
+ vae,
1019
+ None,
1020
+ transformer,
1021
+ neuron_dropout=args.network_dropout,
1022
+ **net_kwargs,
1023
+ )
1024
+ if network is None:
1025
+ return
1026
+
1027
+ network.prepare_network(args)
1028
+
1029
+ # apply network to DiT
1030
+ network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
1031
+
1032
+ if args.network_weights is not None:
1033
+ # FIXME consider alpha of weights: this assumes that the alpha is not changed
1034
+ info = network.load_weights(args.network_weights)
1035
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
1036
+
1037
+ if args.gradient_checkpointing:
1038
+ transformer.enable_gradient_checkpointing()
1039
+ network.enable_gradient_checkpointing() # may have no effect
1040
+
1041
+ # prepare optimizer, data loader etc.
1042
+ accelerator.print("prepare optimizer, data loader etc.")
1043
+
1044
+ trainable_params, lr_descriptions = network.prepare_optimizer_params(unet_lr=args.learning_rate)
1045
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
1046
+ args, trainable_params
1047
+ )
1048
+
1049
+ # prepare dataloader
1050
+
1051
+ # num workers for data loader: if 0, persistent_workers is not available
1052
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
1053
+
1054
+ train_dataloader = torch.utils.data.DataLoader(
1055
+ train_dataset_group,
1056
+ batch_size=1,
1057
+ shuffle=True,
1058
+ collate_fn=collator,
1059
+ num_workers=n_workers,
1060
+ persistent_workers=args.persistent_data_loader_workers,
1061
+ )
1062
+
1063
+ # calculate max_train_steps
1064
+ if args.max_train_epochs is not None:
1065
+ args.max_train_steps = args.max_train_epochs * math.ceil(
1066
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
1067
+ )
1068
+ accelerator.print(
1069
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
1070
+ )
1071
+
1072
+ # send max_train_steps to train_dataset_group
1073
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
1074
+
1075
+ # prepare lr_scheduler
1076
+ lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
1077
+
1078
+ # prepare training model. accelerator does some magic here
1079
+
1080
+ # experimental feature: train the model with gradients in fp16/bf16
1081
+ network_dtype = torch.float32
1082
+ args.full_fp16 = args.full_bf16 = False # temporarily disabled because stochastic rounding is not supported yet
1083
+ if args.full_fp16:
1084
+ assert (
1085
+ args.mixed_precision == "fp16"
1086
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
1087
+ accelerator.print("enable full fp16 training.")
1088
+ network_dtype = weight_dtype
1089
+ network.to(network_dtype)
1090
+ elif args.full_bf16:
1091
+ assert (
1092
+ args.mixed_precision == "bf16"
1093
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
1094
+ accelerator.print("enable full bf16 training.")
1095
+ network_dtype = weight_dtype
1096
+ network.to(network_dtype)
1097
+
1098
+ if dit_weight_dtype != dit_dtype:
1099
+ logger.info(f"casting model to {dit_weight_dtype}")
1100
+ transformer.to(dit_weight_dtype)
1101
+
1102
+ if blocks_to_swap > 0:
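+ # let accelerator wrap the model but skip automatic device placement so the swapped blocks can stay on CPU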
1103
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
1104
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
1105
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
1106
+ else:
1107
+ transformer = accelerator.prepare(transformer)
1108
+
1109
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
1110
+ training_model = network
1111
+
1112
+ if args.gradient_checkpointing:
1113
+ transformer.train()
1114
+ else:
1115
+ transformer.eval()
1116
+
1117
+ accelerator.unwrap_model(network).prepare_grad_etc(transformer)
1118
+
1119
+ if args.full_fp16:
1120
+ # patch accelerator for fp16 training
1121
+ # def patch_accelerator_for_fp16_training(accelerator):
1122
+ org_unscale_grads = accelerator.scaler._unscale_grads_
1123
+
1124
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
1125
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
1126
+
1127
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
1128
+
1129
+ # before resuming, register hooks so that only the network weights are saved/loaded
1130
+ def save_model_hook(models, weights, output_dir):
1131
+ # pop the weights of models other than the network so that only the network weights are saved
1132
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
1133
+ if accelerator.is_main_process: # or args.deepspeed:
1134
+ remove_indices = []
1135
+ for i, model in enumerate(models):
1136
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
1137
+ remove_indices.append(i)
1138
+ for i in reversed(remove_indices):
1139
+ if len(weights) > i:
1140
+ weights.pop(i)
1141
+ # print(f"save model hook: {len(weights)} weights will be saved")
1142
+
1143
+ def load_model_hook(models, input_dir):
1144
+ # remove models except network
1145
+ remove_indices = []
1146
+ for i, model in enumerate(models):
1147
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
1148
+ remove_indices.append(i)
1149
+ for i in reversed(remove_indices):
1150
+ models.pop(i)
1151
+ # print(f"load model hook: {len(models)} models will be loaded")
1152
+
1153
+ accelerator.register_save_state_pre_hook(save_model_hook)
1154
+ accelerator.register_load_state_pre_hook(load_model_hook)
1155
+
1156
+ # resume from local or huggingface. accelerator.step is set
1157
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
1158
+
1159
+ # calculate the number of epochs
1160
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1161
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1162
+
1163
+ # start training
1164
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1165
+
1166
+ accelerator.print("running training / 学習開始")
1167
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
1168
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
1169
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
1170
+ accelerator.print(
1171
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
1172
+ )
1173
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
1174
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
1175
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
1176
+
1177
+ # TODO refactor metadata creation and move to util
1178
+ metadata = {
1179
+ "ss_" "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
1180
+ "ss_training_started_at": training_started_at, # unix timestamp
1181
+ "ss_output_name": args.output_name,
1182
+ "ss_learning_rate": args.learning_rate,
1183
+ "ss_num_train_items": train_dataset_group.num_train_items,
1184
+ "ss_num_batches_per_epoch": len(train_dataloader),
1185
+ "ss_num_epochs": num_train_epochs,
1186
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
1187
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
1188
+ "ss_max_train_steps": args.max_train_steps,
1189
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
1190
+ "ss_lr_scheduler": args.lr_scheduler,
1191
+ SS_METADATA_KEY_BASE_MODEL_VERSION: BASE_MODEL_VERSION_HUNYUAN_VIDEO,
1192
+ # "ss_network_module": args.network_module,
1193
+ # "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
1194
+ # "ss_network_alpha": args.network_alpha, # some networks may not have alpha
1195
+ SS_METADATA_KEY_NETWORK_MODULE: args.network_module,
1196
+ SS_METADATA_KEY_NETWORK_DIM: args.network_dim,
1197
+ SS_METADATA_KEY_NETWORK_ALPHA: args.network_alpha,
1198
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
1199
+ "ss_mixed_precision": args.mixed_precision,
1200
+ "ss_seed": args.seed,
1201
+ "ss_training_comment": args.training_comment, # will not be updated after training
1202
+ # "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
1203
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
1204
+ "ss_max_grad_norm": args.max_grad_norm,
1205
+ "ss_fp8_base": bool(args.fp8_base),
1206
+ "ss_fp8_llm": bool(args.fp8_llm),
1207
+ "ss_full_fp16": bool(args.full_fp16),
1208
+ "ss_full_bf16": bool(args.full_bf16),
1209
+ "ss_weighting_scheme": args.weighting_scheme,
1210
+ "ss_logit_mean": args.logit_mean,
1211
+ "ss_logit_std": args.logit_std,
1212
+ "ss_mode_scale": args.mode_scale,
1213
+ "ss_guidance_scale": args.guidance_scale,
1214
+ "ss_timestep_sampling": args.timestep_sampling,
1215
+ "ss_sigmoid_scale": args.sigmoid_scale,
1216
+ "ss_discrete_flow_shift": args.discrete_flow_shift,
1217
+ }
1218
+
1219
+ datasets_metadata = []
1220
+ # tag_frequency = {} # merge tag frequency for metadata editor # TODO support tag frequency
1221
+ for dataset in train_dataset_group.datasets:
1222
+ dataset_metadata = dataset.get_metadata()
1223
+ datasets_metadata.append(dataset_metadata)
1224
+
1225
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
1226
+
1227
+ # add extra args
1228
+ if args.network_args:
1229
+ # metadata["ss_network_args"] = json.dumps(net_kwargs)
1230
+ metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(net_kwargs)
1231
+
1232
+ # model name and hash
1233
+ if args.dit is not None:
1234
+ logger.info(f"calculate hash for DiT model: {args.dit}")
1235
+ sd_model_name = args.dit
1236
+ if os.path.exists(sd_model_name):
1237
+ metadata["ss_sd_model_hash"] = model_utils.model_hash(sd_model_name)
1238
+ metadata["ss_new_sd_model_hash"] = model_utils.calculate_sha256(sd_model_name)
1239
+ sd_model_name = os.path.basename(sd_model_name)
1240
+ metadata["ss_sd_model_name"] = sd_model_name
1241
+
1242
+ if args.vae is not None:
1243
+ logger.info(f"calculate hash for VAE model: {args.vae}")
1244
+ vae_name = args.vae
1245
+ if os.path.exists(vae_name):
1246
+ metadata["ss_vae_hash"] = model_utils.model_hash(vae_name)
1247
+ metadata["ss_new_vae_hash"] = model_utils.calculate_sha256(vae_name)
1248
+ vae_name = os.path.basename(vae_name)
1249
+ metadata["ss_vae_name"] = vae_name
1250
+
1251
+ metadata = {k: str(v) for k, v in metadata.items()}
1252
+
1253
+ # make minimum metadata for filtering
1254
+ minimum_metadata = {}
1255
+ for key in SS_METADATA_MINIMUM_KEYS:
1256
+ if key in metadata:
1257
+ minimum_metadata[key] = metadata[key]
1258
+
1259
+ if accelerator.is_main_process:
1260
+ init_kwargs = {}
1261
+ if args.wandb_run_name:
1262
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
1263
+ if args.log_tracker_config is not None:
1264
+ init_kwargs = toml.load(args.log_tracker_config)
1265
+ accelerator.init_trackers(
1266
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
1267
+ config=train_utils.get_sanitized_config_or_none(args),
1268
+ init_kwargs=init_kwargs,
1269
+ )
1270
+
1271
+ # TODO skip until initial step
1272
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
1273
+
1274
+ epoch_to_start = 0
1275
+ global_step = 0
1276
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
1277
+
1278
+ loss_recorder = train_utils.LossRecorder()
1279
+ del train_dataset_group
1280
+
1281
+ # function for saving/removing
1282
+ save_dtype = dit_dtype
1283
+
1284
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
1285
+ os.makedirs(args.output_dir, exist_ok=True)
1286
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1287
+
1288
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
1289
+ metadata["ss_training_finished_at"] = str(time.time())
1290
+ metadata["ss_steps"] = str(steps)
1291
+ metadata["ss_epoch"] = str(epoch_no)
1292
+
1293
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
1294
+
1295
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
1296
+ if args.min_timestep is not None or args.max_timestep is not None:
1297
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
1298
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
1299
+ md_timesteps = (min_time_step, max_time_step)
1300
+ else:
1301
+ md_timesteps = None
1302
+
1303
+ sai_metadata = sai_model_spec.build_metadata(
1304
+ None,
1305
+ time.time(),
1306
+ title,
1307
+ None,
1308
+ args.metadata_author,
1309
+ args.metadata_description,
1310
+ args.metadata_license,
1311
+ args.metadata_tags,
1312
+ timesteps=md_timesteps,
1313
+ )
1314
+
1315
+ metadata_to_save.update(sai_metadata)
1316
+
1317
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
1318
+ if args.huggingface_repo_id is not None:
1319
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
1320
+
1321
+ def remove_model(old_ckpt_name):
1322
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1323
+ if os.path.exists(old_ckpt_file):
1324
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
1325
+ os.remove(old_ckpt_file)
1326
+
1327
+ # For --sample_at_first
1328
+ optimizer_eval_fn()
1329
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
1330
+ optimizer_train_fn()
1331
+ if len(accelerator.trackers) > 0:
1332
+ # log empty object to commit the sample images to wandb
1333
+ accelerator.log({}, step=0)
1334
+
1335
+ # training loop
1336
+
1337
+ # log device and dtype for each model
1338
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
1339
+
1340
+ clean_memory_on_device(accelerator.device)
1341
+
1342
+ pos_embed_cache = {}
1343
+
1344
+ for epoch in range(epoch_to_start, num_train_epochs):
1345
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
1346
+ current_epoch.value = epoch + 1
1347
+
1348
+ metadata["ss_epoch"] = str(epoch + 1)
1349
+
1350
+ accelerator.unwrap_model(network).on_epoch_start(transformer)
1351
+
1352
+ for step, batch in enumerate(train_dataloader):
1353
+ latents, llm_embeds, llm_mask, clip_embeds = batch
1354
+ bsz = latents.shape[0]
1355
+ current_step.value = global_step
1356
+
1357
+ with accelerator.accumulate(training_model):
1358
+ accelerator.unwrap_model(network).on_step_start()
1359
+
1360
+ latents = latents * vae_module.SCALING_FACTOR
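+ # apply the VAE latent scaling factor to the cached latents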
1361
+
1362
+ # Sample noise that we'll add to the latents
1363
+ noise = torch.randn_like(latents)
1364
+
1365
+ # calculate model input and timesteps
1366
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
1367
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
1368
+ )
1369
+
1370
+ weighting = compute_loss_weighting_for_sd3(
1371
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
1372
+ )
1373
+
1374
+ # ensure guidance_scale in args is float
1375
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
1376
+
1377
+ # ensure the hidden state will require grad
1378
+ if args.gradient_checkpointing:
1379
+ noisy_model_input.requires_grad_(True)
1380
+ guidance_vec.requires_grad_(True)
1381
+
1382
+ pos_emb_shape = latents.shape[1:]
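+ # RoPE frequencies depend only on the latent shape, so compute them once per shape and reuse across steps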
1383
+ if pos_emb_shape not in pos_embed_cache:
1384
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(transformer, latents.shape[2:])
1385
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
1386
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
1387
+ pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
1388
+ else:
1389
+ freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
1390
+
1391
+ # call DiT
1392
+ latents = latents.to(device=accelerator.device, dtype=network_dtype)
1393
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
1394
+ # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
1395
+ # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
1396
+ # llm_mask = llm_mask.to(device=accelerator.device)
1397
+ # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
1398
+ with accelerator.autocast():
1399
+ model_pred = transformer(
1400
+ noisy_model_input,
1401
+ timesteps,
1402
+ text_states=llm_embeds,
1403
+ text_mask=llm_mask,
1404
+ text_states_2=clip_embeds,
1405
+ freqs_cos=freqs_cos,
1406
+ freqs_sin=freqs_sin,
1407
+ guidance=guidance_vec,
1408
+ return_dict=False,
1409
+ )
1410
+
1411
+ # flow matching loss
1412
+ target = noise - latents
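+ # flow matching velocity target: d/dt[(1 - t) * latents + t * noise] = noise - latents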
1413
+
1414
+ loss = torch.nn.functional.mse_loss(model_pred.to(network_dtype), target, reduction="none")
1415
+
1416
+ if weighting is not None:
1417
+ loss = loss * weighting
1418
+ # loss = loss.mean([1, 2, 3])
1419
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
1420
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
1421
+
1422
+ loss = loss.mean() # mean over all elements, so no need to divide by the batch size
1423
+
1424
+ accelerator.backward(loss)
1425
+ if accelerator.sync_gradients:
1426
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
1427
+ state = accelerate.PartialState()
1428
+ if state.distributed_type != accelerate.DistributedType.NO:
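+ # when running distributed, manually average each network parameter's gradient across processes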
1429
+ for param in network.parameters():
1430
+ if param.grad is not None:
1431
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
1432
+
1433
+ if args.max_grad_norm != 0.0:
1434
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
1435
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1436
+
1437
+ optimizer.step()
1438
+ lr_scheduler.step()
1439
+ optimizer.zero_grad(set_to_none=True)
1440
+
1441
+ if args.scale_weight_norms:
1442
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
1443
+ args.scale_weight_norms, accelerator.device
1444
+ )
1445
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
1446
+ else:
1447
+ keys_scaled, mean_norm, maximum_norm = None, None, None
1448
+
1449
+ # Checks if the accelerator has performed an optimization step behind the scenes
1450
+ if accelerator.sync_gradients:
1451
+ progress_bar.update(1)
1452
+ global_step += 1
1453
+
1454
+ optimizer_eval_fn()
1455
+ self.sample_images(
1456
+ accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
1457
+ )
1458
+
1459
+ # save the model every N steps if specified
1460
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1461
+ accelerator.wait_for_everyone()
1462
+ if accelerator.is_main_process:
1463
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
1464
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
1465
+
1466
+ if args.save_state:
1467
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
1468
+
1469
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
1470
+ if remove_step_no is not None:
1471
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
1472
+ remove_model(remove_ckpt_name)
1473
+ optimizer_train_fn()
1474
+
1475
+ current_loss = loss.detach().item()
1476
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1477
+ avr_loss: float = loss_recorder.moving_average
1478
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1479
+ progress_bar.set_postfix(**logs)
1480
+
1481
+ if args.scale_weight_norms:
1482
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
1483
+
1484
+ if len(accelerator.trackers) > 0:
1485
+ logs = self.generate_step_logs(
1486
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
1487
+ )
1488
+ accelerator.log(logs, step=global_step)
1489
+
1490
+ if global_step >= args.max_train_steps:
1491
+ break
1492
+
1493
+ if len(accelerator.trackers) > 0:
1494
+ logs = {"loss/epoch": loss_recorder.moving_average}
1495
+ accelerator.log(logs, step=epoch + 1)
1496
+
1497
+ accelerator.wait_for_everyone()
1498
+
1499
+ # save the model every N epochs if specified
1500
+ optimizer_eval_fn()
1501
+ if args.save_every_n_epochs is not None:
1502
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1503
+ if is_main_process and saving:
1504
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
1505
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
1506
+
1507
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
1508
+ if remove_epoch_no is not None:
1509
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
1510
+ remove_model(remove_ckpt_name)
1511
+
1512
+ if args.save_state:
1513
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1514
+
1515
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
1516
+ optimizer_train_fn()
1517
+
1518
+ # end of epoch
1519
+
1520
+ # metadata["ss_epoch"] = str(num_train_epochs)
1521
+ metadata["ss_training_finished_at"] = str(time.time())
1522
+
1523
+ if is_main_process:
1524
+ network = accelerator.unwrap_model(network)
1525
+
1526
+ accelerator.end_training()
1527
+ optimizer_eval_fn()
1528
+
1529
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
1530
+ train_utils.save_state_on_train_end(args, accelerator)
1531
+
1532
+ if is_main_process:
1533
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
1534
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
1535
+
1536
+ logger.info("model saved.")
1537
+
1538
+
1539
+ def setup_parser() -> argparse.ArgumentParser:
1540
+ def int_or_float(value):
1541
+ if value.endswith("%"):
1542
+ try:
1543
+ return float(value[:-1]) / 100.0
1544
+ except ValueError:
1545
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
1546
+ try:
1547
+ float_value = float(value)
1548
+ if float_value >= 1 and float_value.is_integer():
1549
+ return int(value)
1550
+ return float(value)
1551
+ except ValueError:
1552
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
1553
+
1554
+ parser = argparse.ArgumentParser()
1555
+
1556
+ # general settings
1557
+ parser.add_argument(
1558
+ "--config_file",
1559
+ type=str,
1560
+ default=None,
1561
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
1562
+ )
1563
+ parser.add_argument(
1564
+ "--dataset_config",
1565
+ type=pathlib.Path,
1566
+ default=None,
1567
+ required=True,
1568
+ help="config file for dataset / データセットの設定ファイル",
1569
+ )
1570
+
1571
+ # training settings
1572
+ parser.add_argument(
1573
+ "--sdpa",
1574
+ action="store_true",
1575
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
1576
+ )
1577
+ parser.add_argument(
1578
+ "--flash_attn",
1579
+ action="store_true",
1580
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
1581
+ )
1582
+ parser.add_argument(
1583
+ "--sage_attn",
1584
+ action="store_true",
1585
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
1586
+ )
1587
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1588
+ parser.add_argument(
1589
+ "--max_train_epochs",
1590
+ type=int,
1591
+ default=None,
1592
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
1593
+ )
1594
+ parser.add_argument(
1595
+ "--max_data_loader_n_workers",
1596
+ type=int,
1597
+ default=8,
1598
+ help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
1599
+ )
1600
+ parser.add_argument(
1601
+ "--persistent_data_loader_workers",
1602
+ action="store_true",
1603
+ help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
1604
+ )
1605
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1606
+ parser.add_argument(
1607
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
1608
+ )
1609
+ parser.add_argument(
1610
+ "--gradient_accumulation_steps",
1611
+ type=int,
1612
+ default=1,
1613
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
1614
+ )
1615
+ parser.add_argument(
1616
+ "--mixed_precision",
1617
+ type=str,
1618
+ default="no",
1619
+ choices=["no", "fp16", "bf16"],
1620
+ help="use mixed precision / 混合精度を使う場合、その精度",
1621
+ )
1622
+
1623
+ parser.add_argument(
1624
+ "--logging_dir",
1625
+ type=str,
1626
+ default=None,
1627
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
1628
+ )
1629
+ parser.add_argument(
1630
+ "--log_with",
1631
+ type=str,
1632
+ default=None,
1633
+ choices=["tensorboard", "wandb", "all"],
1634
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
1635
+ )
1636
+ parser.add_argument(
1637
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
1638
+ )
1639
+ parser.add_argument(
1640
+ "--log_tracker_name",
1641
+ type=str,
1642
+ default=None,
1643
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
1644
+ )
1645
+ parser.add_argument(
1646
+ "--wandb_run_name",
1647
+ type=str,
1648
+ default=None,
1649
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
1650
+ )
1651
+ parser.add_argument(
1652
+ "--log_tracker_config",
1653
+ type=str,
1654
+ default=None,
1655
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
1656
+ )
1657
+ parser.add_argument(
1658
+ "--wandb_api_key",
1659
+ type=str,
1660
+ default=None,
1661
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
1662
+ )
1663
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
1664
+
1665
+ parser.add_argument(
1666
+ "--ddp_timeout",
1667
+ type=int,
1668
+ default=None,
1669
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
1670
+ )
1671
+ parser.add_argument(
1672
+ "--ddp_gradient_as_bucket_view",
1673
+ action="store_true",
1674
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
1675
+ )
1676
+ parser.add_argument(
1677
+ "--ddp_static_graph",
1678
+ action="store_true",
1679
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
1680
+ )
1681
+
1682
+ parser.add_argument(
1683
+ "--sample_every_n_steps",
1684
+ type=int,
1685
+ default=None,
1686
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
1687
+ )
1688
+ parser.add_argument(
1689
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
1690
+ )
1691
+ parser.add_argument(
1692
+ "--sample_every_n_epochs",
1693
+ type=int,
1694
+ default=None,
1695
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
1696
+ )
1697
+ parser.add_argument(
1698
+ "--sample_prompts",
1699
+ type=str,
1700
+ default=None,
1701
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
1702
+ )
1703
+
1704
+ # optimizer and lr scheduler settings
1705
+ parser.add_argument(
1706
+ "--optimizer_type",
1707
+ type=str,
1708
+ default="",
1709
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
1710
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
1711
+ )
1712
+ parser.add_argument(
1713
+ "--optimizer_args",
1714
+ type=str,
1715
+ default=None,
1716
+ nargs="*",
1717
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
1718
+ )
1719
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1720
+ parser.add_argument(
1721
+ "--max_grad_norm",
1722
+ default=1.0,
1723
+ type=float,
1724
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
1725
+ )
1726
+
1727
+ parser.add_argument(
1728
+ "--lr_scheduler",
1729
+ type=str,
1730
+ default="constant",
1731
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
1732
+ )
1733
+ parser.add_argument(
1734
+ "--lr_warmup_steps",
1735
+ type=int_or_float,
1736
+ default=0,
1737
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
1738
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1739
+ )
1740
+ parser.add_argument(
1741
+ "--lr_decay_steps",
1742
+ type=int_or_float,
1743
+ default=0,
1744
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
1745
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1746
+ )
1747
+ parser.add_argument(
1748
+ "--lr_scheduler_num_cycles",
1749
+ type=int,
1750
+ default=1,
1751
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
1752
+ )
1753
+ parser.add_argument(
1754
+ "--lr_scheduler_power",
1755
+ type=float,
1756
+ default=1,
1757
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
1758
+ )
1759
+ parser.add_argument(
1760
+ "--lr_scheduler_timescale",
1761
+ type=int,
1762
+ default=None,
1763
+ help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
1764
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
1765
+ )
1766
+ parser.add_argument(
1767
+ "--lr_scheduler_min_lr_ratio",
1768
+ type=float,
1769
+ default=None,
1770
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
1771
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
1772
+ )
1773
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
1774
+ parser.add_argument(
1775
+ "--lr_scheduler_args",
1776
+ type=str,
1777
+ default=None,
1778
+ nargs="*",
1779
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
1780
+ )
1781
+
1782
+ # model settings
1783
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
1784
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
1785
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
1786
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
1787
+ parser.add_argument(
1788
+ "--vae_tiling",
1789
+ action="store_true",
1790
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
1791
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
1792
+ )
1793
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
1794
+ parser.add_argument(
1795
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
1796
+ )
1797
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
1798
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
1799
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
1800
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
1801
+ parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
1802
+ # parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1803
+ # parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
1804
+
1805
+ parser.add_argument(
1806
+ "--blocks_to_swap",
1807
+ type=int,
1808
+ default=None,
1809
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
1810
+ )
1811
+ parser.add_argument(
1812
+ "--img_in_txt_in_offloading",
1813
+ action="store_true",
1814
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
1815
+ )
1816
+
1817
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
1818
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.")
1819
+ parser.add_argument(
1820
+ "--timestep_sampling",
1821
+ choices=["sigma", "uniform", "sigmoid", "shift"],
1822
+ default="sigma",
1823
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
1824
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
1825
+ )
1826
+ parser.add_argument(
1827
+ "--discrete_flow_shift",
1828
+ type=float,
1829
+ default=1.0,
1830
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
1831
+ )
1832
+ parser.add_argument(
1833
+ "--sigmoid_scale",
1834
+ type=float,
1835
+ default=1.0,
1836
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
1837
+ )
1838
+ parser.add_argument(
1839
+ "--weighting_scheme",
1840
+ type=str,
1841
+ default="none",
1842
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
1843
+ help="weighting scheme for timestep distribution. Default is none"
1844
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
1845
+ )
1846
+ parser.add_argument(
1847
+ "--logit_mean",
1848
+ type=float,
1849
+ default=0.0,
1850
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
1851
+ )
1852
+ parser.add_argument(
1853
+ "--logit_std",
1854
+ type=float,
1855
+ default=1.0,
1856
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
1857
+ )
1858
+ parser.add_argument(
1859
+ "--mode_scale",
1860
+ type=float,
1861
+ default=1.29,
1862
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
1863
+ )
1864
+ parser.add_argument(
1865
+ "--min_timestep",
1866
+ type=int,
1867
+ default=None,
1868
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
1869
+ )
1870
+ parser.add_argument(
1871
+ "--max_timestep",
1872
+ type=int,
1873
+ default=None,
1874
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
1875
+ )
1876
+
1877
+ parser.add_argument(
1878
+ "--show_timesteps",
1879
+ type=str,
1880
+ default=None,
1881
+ choices=["image", "console"],
1882
+ help="show timesteps in image or console, and return to console / タイムステップを画像またはコンソールに表示し、コンソールに戻る",
1883
+ )
1884
+
1885
+ # network settings
1886
+ parser.add_argument(
1887
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
1888
+ )
1889
+ parser.add_argument(
1890
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
1891
+ )
1892
+ parser.add_argument(
1893
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
1894
+ )
1895
+ parser.add_argument(
1896
+ "--network_dim",
1897
+ type=int,
1898
+ default=None,
1899
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
1900
+ )
1901
+ parser.add_argument(
1902
+ "--network_alpha",
1903
+ type=float,
1904
+ default=1,
1905
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
1906
+ )
1907
+ parser.add_argument(
1908
+ "--network_dropout",
1909
+ type=float,
1910
+ default=None,
1911
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
1912
+ )
1913
+ parser.add_argument(
1914
+ "--network_args",
1915
+ type=str,
1916
+ default=None,
1917
+ nargs="*",
1918
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
1919
+ )
1920
+ parser.add_argument(
1921
+ "--training_comment",
1922
+ type=str,
1923
+ default=None,
1924
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
1925
+ )
1926
+ parser.add_argument(
1927
+ "--dim_from_weights",
1928
+ action="store_true",
1929
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
1930
+ )
1931
+ parser.add_argument(
1932
+ "--scale_weight_norms",
1933
+ type=float,
1934
+ default=None,
1935
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
1936
+ )
1937
+ parser.add_argument(
1938
+ "--base_weights",
1939
+ type=str,
1940
+ default=None,
1941
+ nargs="*",
1942
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
1943
+ )
1944
+ parser.add_argument(
1945
+ "--base_weights_multiplier",
1946
+ type=float,
1947
+ default=None,
1948
+ nargs="*",
1949
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
1950
+ )
1951
+
1952
+ # save and load settings
1953
+ parser.add_argument(
1954
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
1955
+ )
1956
+ parser.add_argument(
1957
+ "--output_name",
1958
+ type=str,
1959
+ default=None,
1960
+ required=True,
1961
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
1962
+ )
1963
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1964
+
1965
+ parser.add_argument(
1966
+ "--save_every_n_epochs",
1967
+ type=int,
1968
+ default=None,
1969
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
1970
+ )
1971
+ parser.add_argument(
1972
+ "--save_every_n_steps",
1973
+ type=int,
1974
+ default=None,
1975
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
1976
+ )
1977
+ parser.add_argument(
1978
+ "--save_last_n_epochs",
1979
+ type=int,
1980
+ default=None,
1981
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
1982
+ )
1983
+ parser.add_argument(
1984
+ "--save_last_n_epochs_state",
1985
+ type=int,
1986
+ default=None,
1987
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
1988
+ )
1989
+ parser.add_argument(
1990
+ "--save_last_n_steps",
1991
+ type=int,
1992
+ default=None,
1993
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
1994
+ )
1995
+ parser.add_argument(
1996
+ "--save_last_n_steps_state",
1997
+ type=int,
1998
+ default=None,
1999
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
2000
+ )
2001
+ parser.add_argument(
2002
+ "--save_state",
2003
+ action="store_true",
2004
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
2005
+ )
2006
+ parser.add_argument(
2007
+ "--save_state_on_train_end",
2008
+ action="store_true",
2009
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
2010
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
2011
+ )
2012
+
2013
+ # SAI Model spec
2014
+ parser.add_argument(
2015
+ "--metadata_title",
2016
+ type=str,
2017
+ default=None,
2018
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
2019
+ )
2020
+ parser.add_argument(
2021
+ "--metadata_author",
2022
+ type=str,
2023
+ default=None,
2024
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
2025
+ )
2026
+ parser.add_argument(
2027
+ "--metadata_description",
2028
+ type=str,
2029
+ default=None,
2030
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
2031
+ )
2032
+ parser.add_argument(
2033
+ "--metadata_license",
2034
+ type=str,
2035
+ default=None,
2036
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
2037
+ )
2038
+ parser.add_argument(
2039
+ "--metadata_tags",
2040
+ type=str,
2041
+ default=None,
2042
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
2043
+ )
2044
+
2045
+ # huggingface settings
2046
+ parser.add_argument(
2047
+ "--huggingface_repo_id",
2048
+ type=str,
2049
+ default=None,
2050
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
2051
+ )
2052
+ parser.add_argument(
2053
+ "--huggingface_repo_type",
2054
+ type=str,
2055
+ default=None,
2056
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
2057
+ )
2058
+ parser.add_argument(
2059
+ "--huggingface_path_in_repo",
2060
+ type=str,
2061
+ default=None,
2062
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
2063
+ )
2064
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
2065
+ parser.add_argument(
2066
+ "--huggingface_repo_visibility",
2067
+ type=str,
2068
+ default=None,
2069
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
2070
+ )
2071
+ parser.add_argument(
2072
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
2073
+ )
2074
+ parser.add_argument(
2075
+ "--resume_from_huggingface",
2076
+ action="store_true",
2077
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
2078
+ )
2079
+ parser.add_argument(
2080
+ "--async_upload",
2081
+ action="store_true",
2082
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
2083
+ )
2084
+
2085
+ return parser
2086
+
2087
+
2088
+ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
2089
+ if not args.config_file:
2090
+ return args
2091
+
2092
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
2093
+
2094
+ if not os.path.exists(config_path):
2095
+ logger.error(f"{config_path} not found.")
2096
+ exit(1)
2097
+
2098
+ logger.info(f"Loading settings from {config_path}...")
2099
+ with open(config_path, "r", encoding="utf-8") as f:
2100
+ config_dict = toml.load(f)
2101
+
2102
+ # combine all sections into one
2103
+ ignore_nesting_dict = {}
2104
+ for section_name, section_dict in config_dict.items():
2105
+ # if value is not dict, save key and value as is
2106
+ if not isinstance(section_dict, dict):
2107
+ ignore_nesting_dict[section_name] = section_dict
2108
+ continue
2109
+
2110
+ # if value is dict, save all key and value into one dict
2111
+ for key, value in section_dict.items():
2112
+ ignore_nesting_dict[key] = value
2113
+
2114
+ config_args = argparse.Namespace(**ignore_nesting_dict)
2115
+ args = parser.parse_args(namespace=config_args)
2116
+ args.config_file = os.path.splitext(args.config_file)[0]
2117
+ logger.info(args.config_file)
2118
+
2119
+ return args
2120
+
2121
+
2122
+ if __name__ == "__main__":
2123
+ parser = setup_parser()
2124
+
2125
+ args = parser.parse_args()
2126
+ args = read_config_from_file(args, parser)
2127
+
2128
+ trainer = NetworkTrainer()
2129
+ trainer.train(args)
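read_config_from_file above flattens every TOML table into a single argparse namespace, so section names are purely organizational and each key must match a command-line option name (and be unique across sections). Below is a minimal sketch of that flattening with illustrative section names and values; it mirrors the loop in the function rather than importing the training script, and assumes the toml package from the requirements is installed.

import argparse

import toml

example = """
[network]
network_module = "networks.lora"
network_dim = 32
network_alpha = 16.0

[save]
output_dir = "./output"
output_name = "my_lora"
save_every_n_epochs = 1
"""

config_dict = toml.loads(example)

# mirror the section-flattening loop in read_config_from_file
flat = {}
for section_name, section in config_dict.items():
    if not isinstance(section, dict):
        flat[section_name] = section  # top-level scalars are kept as-is
    else:
        flat.update(section)          # nesting is ignored; all keys land at the top level

args = argparse.Namespace(**flat)
print(args.network_dim, args.output_name)  # -> 32 my_lora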
modules/__init__.py ADDED
File without changes
modules/custom_offloading_utils.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import gc
3
+ import time
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ def clean_memory_on_device(device: torch.device):
10
+ r"""
11
+ Clean memory on the specified device, will be called from training scripts.
12
+ """
13
+ gc.collect()
14
+
15
+ # device may be "cuda" or "cuda:0", so we need to check the type of device
16
+ if device.type == "cuda":
17
+ torch.cuda.empty_cache()
18
+ if device.type == "xpu":
19
+ torch.xpu.empty_cache()
20
+ if device.type == "mps":
21
+ torch.mps.empty_cache()
22
+
23
+
24
+ def synchronize_device(device: torch.device):
25
+ if device.type == "cuda":
26
+ torch.cuda.synchronize()
27
+ elif device.type == "xpu":
28
+ torch.xpu.synchronize()
29
+ elif device.type == "mps":
30
+ torch.mps.synchronize()
31
+
32
+
33
+ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
34
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
35
+
36
+ weight_swap_jobs = []
37
+
38
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
39
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
40
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
41
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
42
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
43
+
44
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
45
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
46
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
47
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
48
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
49
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
50
+ else:
51
+ if module_to_cuda.weight.data.device.type != device.type:
52
+ # print(
53
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
54
+ # )
55
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
56
+
57
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
58
+
59
+ stream = torch.cuda.Stream()
60
+ with torch.cuda.stream(stream):
61
+ # cuda to cpu
62
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
63
+ cuda_data_view.record_stream(stream)
64
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
65
+
66
+ stream.synchronize()
67
+
68
+ # cpu to cuda
69
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
70
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
71
+ module_to_cuda.weight.data = cuda_data_view
72
+
73
+ stream.synchronize()
74
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
75
+
76
+
77
+ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
78
+ """
79
+ not tested
80
+ """
81
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
82
+
83
+ weight_swap_jobs = []
84
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
85
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
86
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
87
+
88
+ # device to cpu
89
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
90
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
91
+
92
+ synchronize_device(device)
93
+
94
+ # cpu to device
95
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
96
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
97
+ module_to_cuda.weight.data = cuda_data_view
98
+
99
+ synchronize_device(device)
100
+
101
+
102
+ def weighs_to_device(layer: nn.Module, device: torch.device):
103
+ for module in layer.modules():
104
+ if hasattr(module, "weight") and module.weight is not None:
105
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
106
+
107
+
108
+ class Offloader:
109
+ """
110
+ common offloading class
111
+ """
112
+
113
+ def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
114
+ self.block_type = block_type
115
+ self.num_blocks = num_blocks
116
+ self.blocks_to_swap = blocks_to_swap
117
+ self.device = device
118
+ self.debug = debug
119
+
120
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
121
+ self.futures = {}
122
+ self.cuda_available = device.type == "cuda"
123
+
124
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
125
+ if self.cuda_available:
126
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
127
+ else:
128
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
129
+
130
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
131
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
132
+ if self.debug:
133
+ start_time = time.perf_counter()
134
+ print(
135
+ f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
136
+ )
137
+
138
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
139
+
140
+ if self.debug:
141
+ print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
142
+ return bidx_to_cpu, bidx_to_cuda # , event
143
+
144
+ block_to_cpu = blocks[block_idx_to_cpu]
145
+ block_to_cuda = blocks[block_idx_to_cuda]
146
+
147
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
148
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
149
+ )
150
+
151
+ def _wait_blocks_move(self, block_idx):
152
+ if block_idx not in self.futures:
153
+ return
154
+
155
+ if self.debug:
156
+ print(f"[{self.block_type}] Wait for block {block_idx}")
157
+ start_time = time.perf_counter()
158
+
159
+ future = self.futures.pop(block_idx)
160
+ _, bidx_to_cuda = future.result()
161
+
162
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
163
+
164
+ if self.debug:
165
+ print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
166
+
167
+
168
+ class ModelOffloader(Offloader):
169
+ """
170
+ supports forward offloading
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ block_type: str,
176
+ blocks: list[nn.Module],
177
+ num_blocks: int,
178
+ blocks_to_swap: int,
179
+ supports_backward: bool,
180
+ device: torch.device,
181
+ debug: bool = False,
182
+ ):
183
+ super().__init__(block_type, num_blocks, blocks_to_swap, device, debug)
184
+
185
+ self.supports_backward = supports_backward
186
+
187
+ if self.supports_backward:
188
+ # register backward hooks
189
+ self.remove_handles = []
190
+ for i, block in enumerate(blocks):
191
+ hook = self.create_backward_hook(blocks, i)
192
+ if hook is not None:
193
+ handle = block.register_full_backward_hook(hook)
194
+ self.remove_handles.append(handle)
195
+
196
+ def __del__(self):
197
+ if self.supports_backward:
198
+ for handle in self.remove_handles:
199
+ handle.remove()
200
+
201
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
202
+ # -1 for 0-based index
203
+ num_blocks_propagated = self.num_blocks - block_index - 1
204
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
205
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
206
+
207
+ if not swapping and not waiting:
208
+ return None
209
+
210
+ # create hook
211
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
212
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
213
+ block_idx_to_wait = block_index - 1
214
+
215
+ def backward_hook(module, grad_input, grad_output):
216
+ if self.debug:
217
+ print(f"Backward hook for block {block_index}")
218
+
219
+ if swapping:
220
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
221
+ if waiting:
222
+ self._wait_blocks_move(block_idx_to_wait)
223
+ return None
224
+
225
+ return backward_hook
226
+
227
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
228
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
229
+ return
230
+
231
+ if self.debug:
232
+ print(f"[{self.block_type}] Prepare block devices before forward")
233
+
234
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
235
+ b.to(self.device)
236
+ weighs_to_device(b, self.device) # make sure weights are on device
237
+
238
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
239
+ b.to(self.device) # move block to device first
240
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
241
+
242
+ synchronize_device(self.device)
243
+ clean_memory_on_device(self.device)
244
+
245
+ def wait_for_block(self, block_idx: int):
246
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
247
+ return
248
+ self._wait_blocks_move(block_idx)
249
+
250
+ def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int):
251
+ # check if blocks_to_swap is enabled
252
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
253
+ return
254
+
255
+ # if supports_backward, we swap blocks more than blocks_to_swap in backward pass
256
+ if self.supports_backward and block_idx >= self.blocks_to_swap:
257
+ return
258
+
259
+ block_idx_to_cpu = block_idx
260
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
261
+ block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
262
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
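The offloading classes above are meant to be driven from a model's forward pass: prepare the block devices once, then wait for each block's weights before running it and schedule the next swap right after. A minimal sketch of that calling pattern follows; the Linear blocks, sizes and swap count are illustrative only (not this repo's models), and a CUDA device is assumed.

import torch
import torch.nn as nn

from modules.custom_offloading_utils import ModelOffloader

device = torch.device("cuda")
blocks = [nn.Linear(128, 128) for _ in range(8)]

offloader = ModelOffloader(
    block_type="demo",
    blocks=blocks,
    num_blocks=len(blocks),
    blocks_to_swap=4,         # keep num_blocks - blocks_to_swap blocks resident on the GPU
    supports_backward=False,  # forward-only offloading in this sketch
    device=device,
)
offloader.prepare_block_devices_before_forward(blocks)

x = torch.randn(2, 128, device=device)
for idx, block in enumerate(blocks):
    offloader.wait_for_block(idx)                       # ensure this block's weights have arrived on the GPU
    x = block(x)
    offloader.submit_move_blocks_forward(blocks, idx)   # start swapping the next needed block in the background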
modules/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+
47
+
48
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
+ """
50
+ Euler scheduler.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ timestep_spacing (`str`, defaults to `"linspace"`):
59
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
+ shift (`float`, defaults to 1.0):
62
+ The shift value for the timestep schedule.
63
+ reverse (`bool`, defaults to `True`):
64
+ Whether to reverse the timestep schedule.
65
+ """
66
+
67
+ _compatibles = []
68
+ order = 1
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_train_timesteps: int = 1000,
74
+ shift: float = 1.0,
75
+ reverse: bool = True,
76
+ solver: str = "euler",
77
+ n_tokens: Optional[int] = None,
78
+ ):
79
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
+
81
+ if not reverse:
82
+ sigmas = sigmas.flip(0)
83
+
84
+ self.sigmas = sigmas
85
+ # the value fed to model
86
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
+
88
+ self._step_index = None
89
+ self._begin_index = None
90
+
91
+ self.supported_solver = ["euler"]
92
+ if solver not in self.supported_solver:
93
+ raise ValueError(
94
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
+ )
96
+
97
+ @property
98
+ def step_index(self):
99
+ """
100
+ The index counter for current timestep. It will increase 1 after each scheduler step.
101
+ """
102
+ return self._step_index
103
+
104
+ @property
105
+ def begin_index(self):
106
+ """
107
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108
+ """
109
+ return self._begin_index
110
+
111
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
+ def set_begin_index(self, begin_index: int = 0):
113
+ """
114
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115
+
116
+ Args:
117
+ begin_index (`int`):
118
+ The begin index for the scheduler.
119
+ """
120
+ self._begin_index = begin_index
121
+
122
+ def _sigma_to_t(self, sigma):
123
+ return sigma * self.config.num_train_timesteps
124
+
125
+ def set_timesteps(
126
+ self,
127
+ num_inference_steps: int,
128
+ device: Union[str, torch.device] = None,
129
+ n_tokens: int = None,
130
+ ):
131
+ """
132
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
+
134
+ Args:
135
+ num_inference_steps (`int`):
136
+ The number of diffusion steps used when generating samples with a pre-trained model.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
+ n_tokens (`int`, *optional*):
140
+ Number of tokens in the input sequence.
141
+ """
142
+ self.num_inference_steps = num_inference_steps
143
+
144
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
+ sigmas = self.sd3_time_shift(sigmas)
146
+
147
+ if not self.config.reverse:
148
+ sigmas = 1 - sigmas
149
+
150
+ self.sigmas = sigmas
151
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
+ dtype=torch.float32, device=device
153
+ )
154
+
155
+ # Reset step index
156
+ self._step_index = None
157
+
158
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
159
+ if schedule_timesteps is None:
160
+ schedule_timesteps = self.timesteps
161
+
162
+ indices = (schedule_timesteps == timestep).nonzero()
163
+
164
+ # The sigma index that is taken for the **very** first `step`
165
+ # is always the second index (or the last index if there is only 1)
166
+ # This way we can ensure we don't accidentally skip a sigma in
167
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
+ pos = 1 if len(indices) > 1 else 0
169
+
170
+ return indices[pos].item()
171
+
172
+ def _init_step_index(self, timestep):
173
+ if self.begin_index is None:
174
+ if isinstance(timestep, torch.Tensor):
175
+ timestep = timestep.to(self.timesteps.device)
176
+ self._step_index = self.index_for_timestep(timestep)
177
+ else:
178
+ self._step_index = self._begin_index
179
+
180
+ def scale_model_input(
181
+ self, sample: torch.Tensor, timestep: Optional[int] = None
182
+ ) -> torch.Tensor:
183
+ return sample
184
+
185
+ def sd3_time_shift(self, t: torch.Tensor):
186
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
+
188
+ def step(
189
+ self,
190
+ model_output: torch.FloatTensor,
191
+ timestep: Union[float, torch.FloatTensor],
192
+ sample: torch.FloatTensor,
193
+ return_dict: bool = True,
194
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
+ """
196
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
+ process from the learned model outputs (most often the predicted noise).
198
+
199
+ Args:
200
+ model_output (`torch.FloatTensor`):
201
+ The direct output from learned diffusion model.
202
+ timestep (`float`):
203
+ The current discrete timestep in the diffusion chain.
204
+ sample (`torch.FloatTensor`):
205
+ A current instance of a sample created by the diffusion process.
206
+ generator (`torch.Generator`, *optional*):
207
+ A random number generator.
208
+ n_tokens (`int`, *optional*):
209
+ Number of tokens in the input sequence.
210
+ return_dict (`bool`):
211
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
212
+ tuple.
213
+
214
+ Returns:
215
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
216
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
217
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
218
+ """
219
+
220
+ if (
221
+ isinstance(timestep, int)
222
+ or isinstance(timestep, torch.IntTensor)
223
+ or isinstance(timestep, torch.LongTensor)
224
+ ):
225
+ raise ValueError(
226
+ (
227
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
229
+ " one of the `scheduler.timesteps` as a timestep."
230
+ ),
231
+ )
232
+
233
+ if self.step_index is None:
234
+ self._init_step_index(timestep)
235
+
236
+ # Upcast to avoid precision issues when computing prev_sample
237
+ sample = sample.to(torch.float32)
238
+
239
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
+
241
+ if self.config.solver == "euler":
242
+ prev_sample = sample + model_output.to(torch.float32) * dt
243
+ else:
244
+ raise ValueError(
245
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
+ )
247
+
248
+ # upon completion increase step index by one
249
+ self._step_index += 1
250
+
251
+ if not return_dict:
252
+ return (prev_sample,)
253
+
254
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
+
256
+ def __len__(self):
257
+ return self.config.num_train_timesteps
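FlowMatchDiscreteScheduler above performs plain Euler integration of the flow-matching ODE: sigma runs from 1 (noise) to 0 (data), optionally time-shifted by shift, and each step applies x <- x + v * dt. A minimal sketch of that loop with a dummy velocity model follows; the shift value, latent shape and step count are illustrative only.

import torch

from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler

scheduler = FlowMatchDiscreteScheduler(shift=7.0)
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

latents = torch.randn(1, 16, 9, 32, 32)  # toy (B, C, F, H, W) latent

def dummy_model(x, t):
    # stand-in for the transformer's velocity prediction
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    velocity = dummy_model(latents, t)
    latents = scheduler.step(velocity, t, latents).prev_sample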
modules/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+ from einops import rearrange
26
+
27
+ from diffusers.utils import logging
28
+ from diffusers.models.activations import get_activation
29
+ from diffusers.models.attention_processor import SpatialNorm
30
+ from diffusers.models.attention_processor import Attention
31
+ from diffusers.models.normalization import AdaGroupNorm
32
+ from diffusers.models.normalization import RMSNorm
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
38
+ seq_len = n_frame * n_hw
39
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
40
+ for i in range(seq_len):
41
+ i_frame = i // n_hw
42
+ mask[i, : (i_frame + 1) * n_hw] = 0
43
+ if batch_size is not None:
44
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
45
+ return mask
46
+
47
+
48
+ class CausalConv3d(nn.Module):
49
+ """
50
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
51
+ This maintains temporal causality in video generation tasks.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ chan_in,
57
+ chan_out,
58
+ kernel_size: Union[int, Tuple[int, int, int]],
59
+ stride: Union[int, Tuple[int, int, int]] = 1,
60
+ dilation: Union[int, Tuple[int, int, int]] = 1,
61
+ pad_mode="replicate",
62
+ chunk_size=0,
63
+ **kwargs,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.pad_mode = pad_mode
68
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
69
+ self.time_causal_padding = padding
70
+ self.chunk_size = chunk_size
71
+
72
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
73
+
74
+ def original_forward(self, x):
75
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
76
+ return self.conv(x)
77
+
78
+ def forward(self, x):
79
+ if self.chunk_size == 0:
80
+ return self.original_forward(x)
81
+
82
+ # if not large, call original forward
83
+ if x.shape[4] < self.chunk_size * 1.5:
84
+ return self.original_forward(x)
85
+
86
+ # # debug: verify the original forward is the same as chunked forward
87
+ # orig_forwarded_value = None
88
+ # if x.shape[4] < self.chunk_size * 4:
89
+ # orig_forwarded_value = self.original_forward(x)
90
+
91
+ # get the kernel size
92
+ kernel_size = self.conv.kernel_size[0] # assume cubic kernel
93
+ assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported"
94
+ padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1
95
+
96
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
97
+
98
+ B, C, D, H, W = orig_shape = x.shape
99
+ chunk_size = self.chunk_size
100
+ chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride
101
+ # print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}")
102
+
103
+ # calculate the indices for chunking with overlap and padding by kernel size and stride
104
+ indices = []
105
+ i = 0
106
+ while i < W - padding_size:
107
+ start_idx = i - padding_size
108
+ end_idx = min(i + chunk_size + padding_size, W)
109
+ if i == 0:
110
+ start_idx = 0
111
+ end_idx += padding_size # to make sure the first chunk is divisible by stride
112
+ if W - end_idx < chunk_size // 2: # small chunk at the end
113
+ end_idx = W
114
+ indices.append((start_idx, end_idx))
115
+ i = end_idx - padding_size
116
+ # print(f"chunked forward: {x.shape}, chunked indices: {indices}")
117
+
118
+ chunks = []
119
+ for start_idx, end_idx in indices:
120
+ chunk = x[:, :, :, :, start_idx:end_idx]
121
+ chunk_output = self.conv(chunk)
122
+ # print(chunk.shape, chunk_output.shape)
123
+ chunks.append(chunk_output)
124
+
125
+ # concatenate the chunks
126
+ x = torch.cat(chunks, dim=4)
127
+
128
+ assert (
129
+ x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0]
130
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
131
+ assert (
132
+ x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1]
133
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
134
+ assert (
135
+ x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2]
136
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
137
+
138
+ # # debug: verify the original forward is the same as chunked forward
139
+ # if orig_forwarded_value is not None:
140
+ # assert torch.allclose(
141
+ # orig_forwarded_value, x, rtol=1e-4, atol=1e-2
142
+ # ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}"
143
+
144
+ return x
145
+
146
+
147
+ class UpsampleCausal3D(nn.Module):
148
+ """
149
+ A 3D upsampling layer with an optional convolution.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ channels: int,
155
+ use_conv: bool = False,
156
+ use_conv_transpose: bool = False,
157
+ out_channels: Optional[int] = None,
158
+ name: str = "conv",
159
+ kernel_size: Optional[int] = None,
160
+ padding=1,
161
+ norm_type=None,
162
+ eps=None,
163
+ elementwise_affine=None,
164
+ bias=True,
165
+ interpolate=True,
166
+ upsample_factor=(2, 2, 2),
167
+ ):
168
+ super().__init__()
169
+ self.channels = channels
170
+ self.out_channels = out_channels or channels
171
+ self.use_conv = use_conv
172
+ self.use_conv_transpose = use_conv_transpose
173
+ self.name = name
174
+ self.interpolate = interpolate
175
+ self.upsample_factor = upsample_factor
176
+
177
+ if norm_type == "ln_norm":
178
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
179
+ elif norm_type == "rms_norm":
180
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
181
+ elif norm_type is None:
182
+ self.norm = None
183
+ else:
184
+ raise ValueError(f"unknown norm_type: {norm_type}")
185
+
186
+ conv = None
187
+ if use_conv_transpose:
188
+ raise NotImplementedError
189
+ elif use_conv:
190
+ if kernel_size is None:
191
+ kernel_size = 3
192
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
193
+
194
+ if name == "conv":
195
+ self.conv = conv
196
+ else:
197
+ self.Conv2d_0 = conv
198
+
199
+ def forward(
200
+ self,
201
+ hidden_states: torch.FloatTensor,
202
+ output_size: Optional[int] = None,
203
+ scale: float = 1.0,
204
+ ) -> torch.FloatTensor:
205
+ assert hidden_states.shape[1] == self.channels
206
+
207
+ if self.norm is not None:
208
+ raise NotImplementedError
209
+
210
+ if self.use_conv_transpose:
211
+ return self.conv(hidden_states)
212
+
213
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
214
+ dtype = hidden_states.dtype
215
+ if dtype == torch.bfloat16:
216
+ hidden_states = hidden_states.to(torch.float32)
217
+
218
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
219
+ if hidden_states.shape[0] >= 64:
220
+ hidden_states = hidden_states.contiguous()
221
+
222
+ # if `output_size` is passed we force the interpolation output
223
+ # size and do not make use of `scale_factor=2`
224
+ if self.interpolate:
225
+ B, C, T, H, W = hidden_states.shape
226
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
227
+ if output_size is None:
228
+ if T > 1:
229
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
230
+
231
+ first_h = first_h.squeeze(2)
232
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
233
+ first_h = first_h.unsqueeze(2)
234
+ else:
235
+ raise NotImplementedError
236
+
237
+ if T > 1:
238
+ hidden_states = torch.cat((first_h, other_h), dim=2)
239
+ else:
240
+ hidden_states = first_h
241
+
242
+ # If the input is bfloat16, we cast back to bfloat16
243
+ if dtype == torch.bfloat16:
244
+ hidden_states = hidden_states.to(dtype)
245
+
246
+ if self.use_conv:
247
+ if self.name == "conv":
248
+ hidden_states = self.conv(hidden_states)
249
+ else:
250
+ hidden_states = self.Conv2d_0(hidden_states)
251
+
252
+ return hidden_states
253
+
254
+
255
+ class DownsampleCausal3D(nn.Module):
256
+ """
257
+ A 3D downsampling layer with an optional convolution.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ channels: int,
263
+ use_conv: bool = False,
264
+ out_channels: Optional[int] = None,
265
+ padding: int = 1,
266
+ name: str = "conv",
267
+ kernel_size=3,
268
+ norm_type=None,
269
+ eps=None,
270
+ elementwise_affine=None,
271
+ bias=True,
272
+ stride=2,
273
+ ):
274
+ super().__init__()
275
+ self.channels = channels
276
+ self.out_channels = out_channels or channels
277
+ self.use_conv = use_conv
278
+ self.padding = padding
279
+ stride = stride
280
+ self.name = name
281
+
282
+ if norm_type == "ln_norm":
283
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
284
+ elif norm_type == "rms_norm":
285
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
286
+ elif norm_type is None:
287
+ self.norm = None
288
+ else:
289
+ raise ValueError(f"unknown norm_type: {norm_type}")
290
+
291
+ if use_conv:
292
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
293
+ else:
294
+ raise NotImplementedError
295
+
296
+ if name == "conv":
297
+ self.Conv2d_0 = conv
298
+ self.conv = conv
299
+ elif name == "Conv2d_0":
300
+ self.conv = conv
301
+ else:
302
+ self.conv = conv
303
+
304
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
305
+ assert hidden_states.shape[1] == self.channels
306
+
307
+ if self.norm is not None:
308
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
309
+
310
+ assert hidden_states.shape[1] == self.channels
311
+
312
+ hidden_states = self.conv(hidden_states)
313
+
314
+ return hidden_states
315
+
316
+
317
+ class ResnetBlockCausal3D(nn.Module):
318
+ r"""
319
+ A Resnet block.
320
+ """
321
+
322
+ def __init__(
323
+ self,
324
+ *,
325
+ in_channels: int,
326
+ out_channels: Optional[int] = None,
327
+ conv_shortcut: bool = False,
328
+ dropout: float = 0.0,
329
+ temb_channels: int = 512,
330
+ groups: int = 32,
331
+ groups_out: Optional[int] = None,
332
+ pre_norm: bool = True,
333
+ eps: float = 1e-6,
334
+ non_linearity: str = "swish",
335
+ skip_time_act: bool = False,
336
+ # default, scale_shift, ada_group, spatial
337
+ time_embedding_norm: str = "default",
338
+ kernel: Optional[torch.FloatTensor] = None,
339
+ output_scale_factor: float = 1.0,
340
+ use_in_shortcut: Optional[bool] = None,
341
+ up: bool = False,
342
+ down: bool = False,
343
+ conv_shortcut_bias: bool = True,
344
+ conv_3d_out_channels: Optional[int] = None,
345
+ ):
346
+ super().__init__()
347
+ self.pre_norm = pre_norm
348
+ self.pre_norm = True
349
+ self.in_channels = in_channels
350
+ out_channels = in_channels if out_channels is None else out_channels
351
+ self.out_channels = out_channels
352
+ self.use_conv_shortcut = conv_shortcut
353
+ self.up = up
354
+ self.down = down
355
+ self.output_scale_factor = output_scale_factor
356
+ self.time_embedding_norm = time_embedding_norm
357
+ self.skip_time_act = skip_time_act
358
+
359
+ linear_cls = nn.Linear
360
+
361
+ if groups_out is None:
362
+ groups_out = groups
363
+
364
+ if self.time_embedding_norm == "ada_group":
365
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
366
+ elif self.time_embedding_norm == "spatial":
367
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
368
+ else:
369
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
370
+
371
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
372
+
373
+ if temb_channels is not None:
374
+ if self.time_embedding_norm == "default":
375
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
376
+ elif self.time_embedding_norm == "scale_shift":
377
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
378
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
379
+ self.time_emb_proj = None
380
+ else:
381
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
382
+ else:
383
+ self.time_emb_proj = None
384
+
385
+ if self.time_embedding_norm == "ada_group":
386
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
387
+ elif self.time_embedding_norm == "spatial":
388
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
389
+ else:
390
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
391
+
392
+ self.dropout = torch.nn.Dropout(dropout)
393
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
394
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
395
+
396
+ self.nonlinearity = get_activation(non_linearity)
397
+
398
+ self.upsample = self.downsample = None
399
+ if self.up:
400
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
401
+ elif self.down:
402
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
403
+
404
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
405
+
406
+ self.conv_shortcut = None
407
+ if self.use_in_shortcut:
408
+ self.conv_shortcut = CausalConv3d(
409
+ in_channels,
410
+ conv_3d_out_channels,
411
+ kernel_size=1,
412
+ stride=1,
413
+ bias=conv_shortcut_bias,
414
+ )
415
+
416
+ def forward(
417
+ self,
418
+ input_tensor: torch.FloatTensor,
419
+ temb: torch.FloatTensor,
420
+ scale: float = 1.0,
421
+ ) -> torch.FloatTensor:
422
+ hidden_states = input_tensor
423
+
424
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
425
+ hidden_states = self.norm1(hidden_states, temb)
426
+ else:
427
+ hidden_states = self.norm1(hidden_states)
428
+
429
+ hidden_states = self.nonlinearity(hidden_states)
430
+
431
+ if self.upsample is not None:
432
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
433
+ if hidden_states.shape[0] >= 64:
434
+ input_tensor = input_tensor.contiguous()
435
+ hidden_states = hidden_states.contiguous()
436
+ input_tensor = self.upsample(input_tensor, scale=scale)
437
+ hidden_states = self.upsample(hidden_states, scale=scale)
438
+ elif self.downsample is not None:
439
+ input_tensor = self.downsample(input_tensor, scale=scale)
440
+ hidden_states = self.downsample(hidden_states, scale=scale)
441
+
442
+ hidden_states = self.conv1(hidden_states)
443
+
444
+ if self.time_emb_proj is not None:
445
+ if not self.skip_time_act:
446
+ temb = self.nonlinearity(temb)
447
+ temb = self.time_emb_proj(temb)[:, :, None, None, None]  # nn.Linear takes a single tensor; broadcast over (T, H, W)
448
+
449
+ if temb is not None and self.time_embedding_norm == "default":
450
+ hidden_states = hidden_states + temb
451
+
452
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
453
+ hidden_states = self.norm2(hidden_states, temb)
454
+ else:
455
+ hidden_states = self.norm2(hidden_states)
456
+
457
+ if temb is not None and self.time_embedding_norm == "scale_shift":
458
+ scale, shift = torch.chunk(temb, 2, dim=1)
459
+ hidden_states = hidden_states * (1 + scale) + shift
460
+
461
+ hidden_states = self.nonlinearity(hidden_states)
462
+
463
+ hidden_states = self.dropout(hidden_states)
464
+ hidden_states = self.conv2(hidden_states)
465
+
466
+ if self.conv_shortcut is not None:
467
+ input_tensor = self.conv_shortcut(input_tensor)
468
+
469
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
470
+
471
+ return output_tensor
472
+
473
+
474
+ def get_down_block3d(
475
+ down_block_type: str,
476
+ num_layers: int,
477
+ in_channels: int,
478
+ out_channels: int,
479
+ temb_channels: int,
480
+ add_downsample: bool,
481
+ downsample_stride: int,
482
+ resnet_eps: float,
483
+ resnet_act_fn: str,
484
+ transformer_layers_per_block: int = 1,
485
+ num_attention_heads: Optional[int] = None,
486
+ resnet_groups: Optional[int] = None,
487
+ cross_attention_dim: Optional[int] = None,
488
+ downsample_padding: Optional[int] = None,
489
+ dual_cross_attention: bool = False,
490
+ use_linear_projection: bool = False,
491
+ only_cross_attention: bool = False,
492
+ upcast_attention: bool = False,
493
+ resnet_time_scale_shift: str = "default",
494
+ attention_type: str = "default",
495
+ resnet_skip_time_act: bool = False,
496
+ resnet_out_scale_factor: float = 1.0,
497
+ cross_attention_norm: Optional[str] = None,
498
+ attention_head_dim: Optional[int] = None,
499
+ downsample_type: Optional[str] = None,
500
+ dropout: float = 0.0,
501
+ ):
502
+ # If attn head dim is not defined, we default it to the number of heads
503
+ if attention_head_dim is None:
504
+ logger.warning(
505
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
506
+ )
507
+ attention_head_dim = num_attention_heads
508
+
509
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
510
+ if down_block_type == "DownEncoderBlockCausal3D":
511
+ return DownEncoderBlockCausal3D(
512
+ num_layers=num_layers,
513
+ in_channels=in_channels,
514
+ out_channels=out_channels,
515
+ dropout=dropout,
516
+ add_downsample=add_downsample,
517
+ downsample_stride=downsample_stride,
518
+ resnet_eps=resnet_eps,
519
+ resnet_act_fn=resnet_act_fn,
520
+ resnet_groups=resnet_groups,
521
+ downsample_padding=downsample_padding,
522
+ resnet_time_scale_shift=resnet_time_scale_shift,
523
+ )
524
+ raise ValueError(f"{down_block_type} does not exist.")
525
+
526
+
527
+ def get_up_block3d(
528
+ up_block_type: str,
529
+ num_layers: int,
530
+ in_channels: int,
531
+ out_channels: int,
532
+ prev_output_channel: int,
533
+ temb_channels: int,
534
+ add_upsample: bool,
535
+ upsample_scale_factor: Tuple,
536
+ resnet_eps: float,
537
+ resnet_act_fn: str,
538
+ resolution_idx: Optional[int] = None,
539
+ transformer_layers_per_block: int = 1,
540
+ num_attention_heads: Optional[int] = None,
541
+ resnet_groups: Optional[int] = None,
542
+ cross_attention_dim: Optional[int] = None,
543
+ dual_cross_attention: bool = False,
544
+ use_linear_projection: bool = False,
545
+ only_cross_attention: bool = False,
546
+ upcast_attention: bool = False,
547
+ resnet_time_scale_shift: str = "default",
548
+ attention_type: str = "default",
549
+ resnet_skip_time_act: bool = False,
550
+ resnet_out_scale_factor: float = 1.0,
551
+ cross_attention_norm: Optional[str] = None,
552
+ attention_head_dim: Optional[int] = None,
553
+ upsample_type: Optional[str] = None,
554
+ dropout: float = 0.0,
555
+ ) -> nn.Module:
556
+ # If attn head dim is not defined, we default it to the number of heads
557
+ if attention_head_dim is None:
558
+ logger.warning(
559
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
560
+ )
561
+ attention_head_dim = num_attention_heads
562
+
563
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
564
+ if up_block_type == "UpDecoderBlockCausal3D":
565
+ return UpDecoderBlockCausal3D(
566
+ num_layers=num_layers,
567
+ in_channels=in_channels,
568
+ out_channels=out_channels,
569
+ resolution_idx=resolution_idx,
570
+ dropout=dropout,
571
+ add_upsample=add_upsample,
572
+ upsample_scale_factor=upsample_scale_factor,
573
+ resnet_eps=resnet_eps,
574
+ resnet_act_fn=resnet_act_fn,
575
+ resnet_groups=resnet_groups,
576
+ resnet_time_scale_shift=resnet_time_scale_shift,
577
+ temb_channels=temb_channels,
578
+ )
579
+ raise ValueError(f"{up_block_type} does not exist.")
580
+
581
+
582
+ class UNetMidBlockCausal3D(nn.Module):
583
+ """
584
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
585
+ """
586
+
587
+ def __init__(
588
+ self,
589
+ in_channels: int,
590
+ temb_channels: int,
591
+ dropout: float = 0.0,
592
+ num_layers: int = 1,
593
+ resnet_eps: float = 1e-6,
594
+ resnet_time_scale_shift: str = "default", # default, spatial
595
+ resnet_act_fn: str = "swish",
596
+ resnet_groups: int = 32,
597
+ attn_groups: Optional[int] = None,
598
+ resnet_pre_norm: bool = True,
599
+ add_attention: bool = True,
600
+ attention_head_dim: int = 1,
601
+ output_scale_factor: float = 1.0,
602
+ ):
603
+ super().__init__()
604
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
605
+ self.add_attention = add_attention
606
+
607
+ if attn_groups is None:
608
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
609
+
610
+ # there is always at least one resnet
611
+ resnets = [
612
+ ResnetBlockCausal3D(
613
+ in_channels=in_channels,
614
+ out_channels=in_channels,
615
+ temb_channels=temb_channels,
616
+ eps=resnet_eps,
617
+ groups=resnet_groups,
618
+ dropout=dropout,
619
+ time_embedding_norm=resnet_time_scale_shift,
620
+ non_linearity=resnet_act_fn,
621
+ output_scale_factor=output_scale_factor,
622
+ pre_norm=resnet_pre_norm,
623
+ )
624
+ ]
625
+ attentions = []
626
+
627
+ if attention_head_dim is None:
628
+ logger.warning(
629
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
630
+ )
631
+ attention_head_dim = in_channels
632
+
633
+ for _ in range(num_layers):
634
+ if self.add_attention:
635
+ attentions.append(
636
+ Attention(
637
+ in_channels,
638
+ heads=in_channels // attention_head_dim,
639
+ dim_head=attention_head_dim,
640
+ rescale_output_factor=output_scale_factor,
641
+ eps=resnet_eps,
642
+ norm_num_groups=attn_groups,
643
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
644
+ residual_connection=True,
645
+ bias=True,
646
+ upcast_softmax=True,
647
+ _from_deprecated_attn_block=True,
648
+ )
649
+ )
650
+ else:
651
+ attentions.append(None)
652
+
653
+ resnets.append(
654
+ ResnetBlockCausal3D(
655
+ in_channels=in_channels,
656
+ out_channels=in_channels,
657
+ temb_channels=temb_channels,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+
668
+ self.attentions = nn.ModuleList(attentions)
669
+ self.resnets = nn.ModuleList(resnets)
670
+
671
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
672
+ hidden_states = self.resnets[0](hidden_states, temb)
673
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
674
+ if attn is not None:
675
+ B, C, T, H, W = hidden_states.shape
676
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
677
+ attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
678
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
679
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
680
+ hidden_states = resnet(hidden_states, temb)
681
+
682
+ return hidden_states
683
+
684
+
685
+ class DownEncoderBlockCausal3D(nn.Module):
686
+ def __init__(
687
+ self,
688
+ in_channels: int,
689
+ out_channels: int,
690
+ dropout: float = 0.0,
691
+ num_layers: int = 1,
692
+ resnet_eps: float = 1e-6,
693
+ resnet_time_scale_shift: str = "default",
694
+ resnet_act_fn: str = "swish",
695
+ resnet_groups: int = 32,
696
+ resnet_pre_norm: bool = True,
697
+ output_scale_factor: float = 1.0,
698
+ add_downsample: bool = True,
699
+ downsample_stride: int = 2,
700
+ downsample_padding: int = 1,
701
+ ):
702
+ super().__init__()
703
+ resnets = []
704
+
705
+ for i in range(num_layers):
706
+ in_channels = in_channels if i == 0 else out_channels
707
+ resnets.append(
708
+ ResnetBlockCausal3D(
709
+ in_channels=in_channels,
710
+ out_channels=out_channels,
711
+ temb_channels=None,
712
+ eps=resnet_eps,
713
+ groups=resnet_groups,
714
+ dropout=dropout,
715
+ time_embedding_norm=resnet_time_scale_shift,
716
+ non_linearity=resnet_act_fn,
717
+ output_scale_factor=output_scale_factor,
718
+ pre_norm=resnet_pre_norm,
719
+ )
720
+ )
721
+
722
+ self.resnets = nn.ModuleList(resnets)
723
+
724
+ if add_downsample:
725
+ self.downsamplers = nn.ModuleList(
726
+ [
727
+ DownsampleCausal3D(
728
+ out_channels,
729
+ use_conv=True,
730
+ out_channels=out_channels,
731
+ padding=downsample_padding,
732
+ name="op",
733
+ stride=downsample_stride,
734
+ )
735
+ ]
736
+ )
737
+ else:
738
+ self.downsamplers = None
739
+
740
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
741
+ for resnet in self.resnets:
742
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
743
+
744
+ if self.downsamplers is not None:
745
+ for downsampler in self.downsamplers:
746
+ hidden_states = downsampler(hidden_states, scale)
747
+
748
+ return hidden_states
749
+
750
+
751
+ class UpDecoderBlockCausal3D(nn.Module):
752
+ def __init__(
753
+ self,
754
+ in_channels: int,
755
+ out_channels: int,
756
+ resolution_idx: Optional[int] = None,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default", # default, spatial
761
+ resnet_act_fn: str = "swish",
762
+ resnet_groups: int = 32,
763
+ resnet_pre_norm: bool = True,
764
+ output_scale_factor: float = 1.0,
765
+ add_upsample: bool = True,
766
+ upsample_scale_factor=(2, 2, 2),
767
+ temb_channels: Optional[int] = None,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+
772
+ for i in range(num_layers):
773
+ input_channels = in_channels if i == 0 else out_channels
774
+
775
+ resnets.append(
776
+ ResnetBlockCausal3D(
777
+ in_channels=input_channels,
778
+ out_channels=out_channels,
779
+ temb_channels=temb_channels,
780
+ eps=resnet_eps,
781
+ groups=resnet_groups,
782
+ dropout=dropout,
783
+ time_embedding_norm=resnet_time_scale_shift,
784
+ non_linearity=resnet_act_fn,
785
+ output_scale_factor=output_scale_factor,
786
+ pre_norm=resnet_pre_norm,
787
+ )
788
+ )
789
+
790
+ self.resnets = nn.ModuleList(resnets)
791
+
792
+ if add_upsample:
793
+ self.upsamplers = nn.ModuleList(
794
+ [
795
+ UpsampleCausal3D(
796
+ out_channels,
797
+ use_conv=True,
798
+ out_channels=out_channels,
799
+ upsample_factor=upsample_scale_factor,
800
+ )
801
+ ]
802
+ )
803
+ else:
804
+ self.upsamplers = None
805
+
806
+ self.resolution_idx = resolution_idx
807
+
808
+ def forward(
809
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
810
+ ) -> torch.FloatTensor:
811
+ for resnet in self.resnets:
812
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
813
+
814
+ if self.upsamplers is not None:
815
+ for upsampler in self.upsamplers:
816
+ hidden_states = upsampler(hidden_states)
817
+
818
+ return hidden_states
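A small, self-contained check of what prepare_causal_attention_mask above produces: with n_frame frames of n_hw spatial tokens each, every token may attend to its own frame and all earlier frames, while later frames are masked with -inf. The toy sizes below are illustrative; the module imports cleanly when diffusers and einops from the requirements are installed.

import torch

from modules.unet_causal_3d_blocks import prepare_causal_attention_mask

mask = prepare_causal_attention_mask(n_frame=3, n_hw=2, dtype=torch.float32, device="cpu")
print(mask.shape)        # torch.Size([6, 6])
print((mask == 0).int())
# Each row is allowed (1) up to the end of its own frame and blocked (0 -> -inf) afterwards:
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])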
networks/__init__.py ADDED
File without changes
networks/lora.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
1
+ # LoRA network module: currently conv2d is not fully supported
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import Dict, List, Optional, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+ HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]
21
+
22
+
23
+ class LoRAModule(torch.nn.Module):
24
+ """
25
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ lora_name,
31
+ org_module: torch.nn.Module,
32
+ multiplier=1.0,
33
+ lora_dim=4,
34
+ alpha=1,
35
+ dropout=None,
36
+ rank_dropout=None,
37
+ module_dropout=None,
38
+ split_dims: Optional[List[int]] = None,
39
+ ):
40
+ """
41
+ if alpha == 0 or None, alpha is rank (no scaling).
42
+
43
+ split_dims is used to mimic the split qkv of multi-head attention.
44
+ """
45
+ super().__init__()
46
+ self.lora_name = lora_name
47
+
48
+ if org_module.__class__.__name__ == "Conv2d":
49
+ in_dim = org_module.in_channels
50
+ out_dim = org_module.out_channels
51
+ else:
52
+ in_dim = org_module.in_features
53
+ out_dim = org_module.out_features
54
+
55
+ self.lora_dim = lora_dim
56
+ self.split_dims = split_dims
57
+
58
+ if split_dims is None:
59
+ if org_module.__class__.__name__ == "Conv2d":
60
+ kernel_size = org_module.kernel_size
61
+ stride = org_module.stride
62
+ padding = org_module.padding
63
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
64
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
65
+ else:
66
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
67
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
68
+
69
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
70
+ torch.nn.init.zeros_(self.lora_up.weight)
71
+ else:
72
+ # conv2d not supported
73
+ assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
74
+ assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
75
+ # print(f"split_dims: {split_dims}")
76
+ self.lora_down = torch.nn.ModuleList(
77
+ [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
78
+ )
79
+ self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
80
+ for lora_down in self.lora_down:
81
+ torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
82
+ for lora_up in self.lora_up:
83
+ torch.nn.init.zeros_(lora_up.weight)
84
+
85
+ if type(alpha) == torch.Tensor:
86
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
87
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
88
+ self.scale = alpha / self.lora_dim
89
+ self.register_buffer("alpha", torch.tensor(alpha)) # for save/load
90
+
91
+ # same as microsoft's
92
+ self.multiplier = multiplier
93
+ self.org_module = org_module # remove in applying
94
+ self.dropout = dropout
95
+ self.rank_dropout = rank_dropout
96
+ self.module_dropout = module_dropout
97
+
98
+ def apply_to(self):
99
+ self.org_forward = self.org_module.forward
100
+ self.org_module.forward = self.forward
101
+ del self.org_module
102
+
103
+ def forward(self, x):
104
+ org_forwarded = self.org_forward(x)
105
+
106
+ # module dropout
107
+ if self.module_dropout is not None and self.training:
108
+ if torch.rand(1) < self.module_dropout:
109
+ return org_forwarded
110
+
111
+ if self.split_dims is None:
112
+ lx = self.lora_down(x)
113
+
114
+ # normal dropout
115
+ if self.dropout is not None and self.training:
116
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
117
+
118
+ # rank dropout
119
+ if self.rank_dropout is not None and self.training:
120
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
121
+ if len(lx.size()) == 3:
122
+ mask = mask.unsqueeze(1) # for Text Encoder
123
+ elif len(lx.size()) == 4:
124
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
125
+ lx = lx * mask
126
+
127
+ # scaling for rank dropout: treat as if the rank is changed
128
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
129
+ else:
130
+ scale = self.scale
131
+
132
+ lx = self.lora_up(lx)
133
+
134
+ return org_forwarded + lx * self.multiplier * scale
135
+ else:
136
+ lxs = [lora_down(x) for lora_down in self.lora_down]
137
+
138
+ # normal dropout
139
+ if self.dropout is not None and self.training:
140
+ lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
141
+
142
+ # rank dropout
143
+ if self.rank_dropout is not None and self.training:
144
+ masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
145
+ for i in range(len(lxs)):
146
+ if len(lxs[i].size()) == 3:
147
+ masks[i] = masks[i].unsqueeze(1)
148
+ elif len(lxs[i].size()) == 4:
149
+ masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
150
+ lxs[i] = lxs[i] * masks[i]
151
+
152
+ # scaling for rank dropout: treat as if the rank is changed
153
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
154
+ else:
155
+ scale = self.scale
156
+
157
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
158
+
159
+ return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
160
+
161
+
162
+ class LoRAInfModule(LoRAModule):
163
+ def __init__(
164
+ self,
165
+ lora_name,
166
+ org_module: torch.nn.Module,
167
+ multiplier=1.0,
168
+ lora_dim=4,
169
+ alpha=1,
170
+ **kwargs,
171
+ ):
172
+ # no dropout for inference
173
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
174
+
175
+ self.org_module_ref = [org_module] # for reference
176
+ self.enabled = True
177
+ self.network: LoRANetwork = None
178
+
179
+ def set_network(self, network):
180
+ self.network = network
181
+
182
+ # merge weight to org_module
183
+ def merge_to(self, sd, dtype, device):
184
+ # extract weight from org_module
185
+ org_sd = self.org_module.state_dict()
186
+ weight = org_sd["weight"]
187
+ org_dtype = weight.dtype
188
+ org_device = weight.device
189
+ weight = weight.to(device, dtype=torch.float) # for calculation
190
+
191
+ if dtype is None:
192
+ dtype = org_dtype
193
+ if device is None:
194
+ device = org_device
195
+
196
+ if self.split_dims is None:
197
+ # get up/down weight
198
+ down_weight = sd["lora_down.weight"].to(device, dtype=torch.float)
199
+ up_weight = sd["lora_up.weight"].to(device, dtype=torch.float)
200
+
201
+ # merge weight
202
+ if len(weight.size()) == 2:
203
+ # linear
204
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
205
+ elif down_weight.size()[2:4] == (1, 1):
206
+ # conv2d 1x1
207
+ weight = (
208
+ weight
209
+ + self.multiplier
210
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
211
+ * self.scale
212
+ )
213
+ else:
214
+ # conv2d 3x3
215
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
216
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
217
+ weight = weight + self.multiplier * conved * self.scale
218
+
219
+ # set weight to org_module
220
+ org_sd["weight"] = weight.to(org_device, dtype=dtype)
221
+ self.org_module.load_state_dict(org_sd)
222
+ else:
223
+ # split_dims
224
+ total_dims = sum(self.split_dims)
225
+ for i in range(len(self.split_dims)):
226
+ # get up/down weight
227
+ down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
228
+ up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
229
+
230
+ # pad up_weight -> (total_dims, rank)
231
+ padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
232
+ padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
233
+
234
+ # merge weight
235
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
236
+
237
+ # set weight to org_module
238
+ org_sd["weight"] = weight.to(dtype)
239
+ self.org_module.load_state_dict(org_sd)
240
+
241
+ # return weight for merge
242
+ def get_weight(self, multiplier=None):
243
+ if multiplier is None:
244
+ multiplier = self.multiplier
245
+
246
+ # get up/down weight from module
247
+ up_weight = self.lora_up.weight.to(torch.float)
248
+ down_weight = self.lora_down.weight.to(torch.float)
249
+
250
+ # pre-calculated weight
251
+ if len(down_weight.size()) == 2:
252
+ # linear
253
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
254
+ elif down_weight.size()[2:4] == (1, 1):
255
+ # conv2d 1x1
256
+ weight = (
257
+ self.multiplier
258
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
259
+ * self.scale
260
+ )
261
+ else:
262
+ # conv2d 3x3
263
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
264
+ weight = self.multiplier * conved * self.scale
265
+
266
+ return weight
267
+
268
+ def default_forward(self, x):
269
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
270
+ if self.split_dims is None:
271
+ lx = self.lora_down(x)
272
+ lx = self.lora_up(lx)
273
+ return self.org_forward(x) + lx * self.multiplier * self.scale
274
+ else:
275
+ lxs = [lora_down(x) for lora_down in self.lora_down]
276
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
277
+ return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
278
+
279
+ def forward(self, x):
280
+ if not self.enabled:
281
+ return self.org_forward(x)
282
+ return self.default_forward(x)
283
+
284
+
285
+ def create_network_hunyuan_video(
286
+ multiplier: float,
287
+ network_dim: Optional[int],
288
+ network_alpha: Optional[float],
289
+ vae: nn.Module,
290
+ text_encoders: List[nn.Module],
291
+ unet: nn.Module,
292
+ neuron_dropout: Optional[float] = None,
293
+ **kwargs,
294
+ ):
295
+ return create_network(
296
+ HUNYUAN_TARGET_REPLACE_MODULES,
297
+ "lora_unet",
298
+ multiplier,
299
+ network_dim,
300
+ network_alpha,
301
+ vae,
302
+ text_encoders,
303
+ unet,
304
+ neuron_dropout=neuron_dropout,
305
+ **kwargs,
306
+ )
307
+
308
+
309
+ def create_network(
310
+ target_replace_modules: List[str],
311
+ prefix: str,
312
+ multiplier: float,
313
+ network_dim: Optional[int],
314
+ network_alpha: Optional[float],
315
+ vae: nn.Module,
316
+ text_encoders: List[nn.Module],
317
+ unet: nn.Module,
318
+ neuron_dropout: Optional[float] = None,
319
+ **kwargs,
320
+ ):
321
+ if network_dim is None:
322
+ network_dim = 4 # default
323
+ if network_alpha is None:
324
+ network_alpha = 1.0
325
+
326
+ # extract dim/alpha for conv2d, and block dim
327
+ conv_dim = kwargs.get("conv_dim", None)
328
+ conv_alpha = kwargs.get("conv_alpha", None)
329
+ if conv_dim is not None:
330
+ conv_dim = int(conv_dim)
331
+ if conv_alpha is None:
332
+ conv_alpha = 1.0
333
+ else:
334
+ conv_alpha = float(conv_alpha)
335
+
336
+ # TODO generic rank/dim setting with regular expression
337
+
338
+ # rank/module dropout
339
+ rank_dropout = kwargs.get("rank_dropout", None)
340
+ if rank_dropout is not None:
341
+ rank_dropout = float(rank_dropout)
342
+ module_dropout = kwargs.get("module_dropout", None)
343
+ if module_dropout is not None:
344
+ module_dropout = float(module_dropout)
345
+
346
+ # verbose
347
+ verbose = kwargs.get("verbose", False)
348
+ if verbose is not None:
349
+ verbose = True if verbose == "True" else False
350
+
351
+ # too many arguments ( ^ω^)・・・
352
+ network = LoRANetwork(
353
+ target_replace_modules,
354
+ prefix,
355
+ text_encoders,
356
+ unet,
357
+ multiplier=multiplier,
358
+ lora_dim=network_dim,
359
+ alpha=network_alpha,
360
+ dropout=neuron_dropout,
361
+ rank_dropout=rank_dropout,
362
+ module_dropout=module_dropout,
363
+ conv_lora_dim=conv_dim,
364
+ conv_alpha=conv_alpha,
365
+ verbose=verbose,
366
+ )
367
+
368
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
369
+ # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
370
+ # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
371
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
372
+ # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
373
+ # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
374
+ if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
375
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
376
+
377
+ return network
378
+
379
+
380
+ class LoRANetwork(torch.nn.Module):
381
+ # only supports U-Net (DiT), Text Encoders are not supported
382
+
383
+ def __init__(
384
+ self,
385
+ target_replace_modules: List[str],
386
+ prefix: str,
387
+ text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
388
+ unet: nn.Module,
389
+ multiplier: float = 1.0,
390
+ lora_dim: int = 4,
391
+ alpha: float = 1,
392
+ dropout: Optional[float] = None,
393
+ rank_dropout: Optional[float] = None,
394
+ module_dropout: Optional[float] = None,
395
+ conv_lora_dim: Optional[int] = None,
396
+ conv_alpha: Optional[float] = None,
397
+ module_class: Type[object] = LoRAModule,
398
+ modules_dim: Optional[Dict[str, int]] = None,
399
+ modules_alpha: Optional[Dict[str, int]] = None,
400
+ verbose: Optional[bool] = False,
401
+ ) -> None:
402
+ super().__init__()
403
+ self.multiplier = multiplier
404
+
405
+ self.lora_dim = lora_dim
406
+ self.alpha = alpha
407
+ self.conv_lora_dim = conv_lora_dim
408
+ self.conv_alpha = conv_alpha
409
+ self.dropout = dropout
410
+ self.rank_dropout = rank_dropout
411
+ self.module_dropout = module_dropout
412
+ self.target_replace_modules = target_replace_modules
413
+ self.prefix = prefix
414
+
415
+ self.loraplus_lr_ratio = None
416
+ # self.loraplus_unet_lr_ratio = None
417
+ # self.loraplus_text_encoder_lr_ratio = None
418
+
419
+ if modules_dim is not None:
420
+ logger.info(f"create LoRA network from weights")
421
+ else:
422
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
423
+ logger.info(
424
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
425
+ )
426
+ # if self.conv_lora_dim is not None:
427
+ # logger.info(
428
+ # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
429
+ # )
430
+ # if train_t5xxl:
431
+ # logger.info(f"train T5XXL as well")
432
+
433
+ # create module instances
434
+ def create_modules(
435
+ is_unet: bool,
436
+ pfx: str,
437
+ root_module: torch.nn.Module,
438
+ target_replace_mods: List[str],
439
+ filter: Optional[str] = None,
440
+ default_dim: Optional[int] = None,
441
+ ) -> List[LoRAModule]:
442
+ loras = []
443
+ skipped = []
444
+ for name, module in root_module.named_modules():
445
+ if target_replace_mods is None or module.__class__.__name__ in target_replace_mods:
446
+ if target_replace_mods is None: # dirty hack for all modules
447
+ module = root_module # search all modules
448
+
449
+ for child_name, child_module in module.named_modules():
450
+ is_linear = child_module.__class__.__name__ == "Linear"
451
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
452
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
453
+
454
+ if is_linear or is_conv2d:
455
+ original_name = (name + "." if name else "") + child_name
456
+ lora_name = f"{pfx}.{original_name}".replace(".", "_")
457
+
458
+ if filter is not None and not filter in lora_name:
459
+ continue
460
+
461
+ dim = None
462
+ alpha = None
463
+
464
+ if modules_dim is not None:
465
+ # per-module dim/alpha is specified
466
+ if lora_name in modules_dim:
467
+ dim = modules_dim[lora_name]
468
+ alpha = modules_alpha[lora_name]
469
+ else:
470
+ # normally, all modules are targeted
471
+ if is_linear or is_conv2d_1x1:
472
+ dim = default_dim if default_dim is not None else self.lora_dim
473
+ alpha = self.alpha
474
+ elif self.conv_lora_dim is not None:
475
+ dim = self.conv_lora_dim
476
+ alpha = self.conv_alpha
477
+
478
+ if dim is None or dim == 0:
479
+ # output information about skipped modules
480
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
481
+ skipped.append(lora_name)
482
+ continue
483
+
484
+ lora = module_class(
485
+ lora_name,
486
+ child_module,
487
+ self.multiplier,
488
+ dim,
489
+ alpha,
490
+ dropout=dropout,
491
+ rank_dropout=rank_dropout,
492
+ module_dropout=module_dropout,
493
+ )
494
+ loras.append(lora)
495
+
496
+ if target_replace_mods is None:
497
+ break # all modules are searched
498
+ return loras, skipped
499
+
500
+ # # create LoRA for text encoder
501
+ # # it is redundant to create LoRA modules even if they are not used
502
+
503
+ self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
504
+ # skipped_te = []
505
+ # for i, text_encoder in enumerate(text_encoders):
506
+ # index = i
507
+ # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
508
+ # break
509
+ # logger.info(f"create LoRA for Text Encoder {index+1}:")
510
+ # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
511
+ # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
512
+ # self.text_encoder_loras.extend(text_encoder_loras)
513
+ # skipped_te += skipped
514
+
515
+ # create LoRA for U-Net
516
+ self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
517
+ self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules)
518
+
519
+ logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.")
520
+ if verbose:
521
+ for lora in self.unet_loras:
522
+ logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
523
+
524
+ skipped = skipped_un
525
+ if verbose and len(skipped) > 0:
526
+ logger.warning(
527
+ f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
528
+ )
529
+ for name in skipped:
530
+ logger.info(f"\t{name}")
531
+
532
+ # assertion
533
+ names = set()
534
+ for lora in self.text_encoder_loras + self.unet_loras:
535
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
536
+ names.add(lora.lora_name)
537
+
538
+ def prepare_network(self, args):
539
+ """
540
+ called after the network is created
541
+ """
542
+ pass
543
+
544
+ def set_multiplier(self, multiplier):
545
+ self.multiplier = multiplier
546
+ for lora in self.text_encoder_loras + self.unet_loras:
547
+ lora.multiplier = self.multiplier
548
+
549
+ def set_enabled(self, is_enabled):
550
+ for lora in self.text_encoder_loras + self.unet_loras:
551
+ lora.enabled = is_enabled
552
+
553
+ def load_weights(self, file):
554
+ if os.path.splitext(file)[1] == ".safetensors":
555
+ from safetensors.torch import load_file
556
+
557
+ weights_sd = load_file(file)
558
+ else:
559
+ weights_sd = torch.load(file, map_location="cpu")
560
+
561
+ info = self.load_state_dict(weights_sd, False)
562
+ return info
563
+
564
+ def apply_to(
565
+ self,
566
+ text_encoders: Optional[nn.Module],
567
+ unet: Optional[nn.Module],
568
+ apply_text_encoder: bool = True,
569
+ apply_unet: bool = True,
570
+ ):
571
+ if apply_text_encoder:
572
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
573
+ else:
574
+ self.text_encoder_loras = []
575
+
576
+ if apply_unet:
577
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
578
+ else:
579
+ self.unet_loras = []
580
+
581
+ for lora in self.text_encoder_loras + self.unet_loras:
582
+ lora.apply_to()
583
+ self.add_module(lora.lora_name, lora)
584
+
585
+ # returns whether the weights can be merged
586
+ def is_mergeable(self):
587
+ return True
588
+
589
+ # TODO refactor to common function with apply_to
590
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
591
+ for lora in self.text_encoder_loras + self.unet_loras:
592
+ sd_for_lora = {}
593
+ for key in weights_sd.keys():
594
+ if key.startswith(lora.lora_name):
595
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
596
+ if len(sd_for_lora) == 0:
597
+ logger.info(f"no weight for {lora.lora_name}")
598
+ continue
599
+ lora.merge_to(sd_for_lora, dtype, device)
600
+
601
+ logger.info(f"weights are merged")
602
+
603
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
604
+ self.loraplus_lr_ratio = loraplus_lr_ratio
605
+
606
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}")
607
+ # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
608
+
609
+ def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs):
610
+ self.requires_grad_(True)
611
+
612
+ all_params = []
613
+ lr_descriptions = []
614
+
615
+ def assemble_params(loras, lr, loraplus_ratio):
616
+ param_groups = {"lora": {}, "plus": {}}
617
+ for lora in loras:
618
+ for name, param in lora.named_parameters():
619
+ if loraplus_ratio is not None and "lora_up" in name:
620
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
621
+ else:
622
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
623
+
624
+ params = []
625
+ descriptions = []
626
+ for key in param_groups.keys():
627
+ param_data = {"params": param_groups[key].values()}
628
+
629
+ if len(param_data["params"]) == 0:
630
+ continue
631
+
632
+ if lr is not None:
633
+ if key == "plus":
634
+ param_data["lr"] = lr * loraplus_ratio
635
+ else:
636
+ param_data["lr"] = lr
637
+
638
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
639
+ logger.info("NO LR skipping!")
640
+ continue
641
+
642
+ params.append(param_data)
643
+ descriptions.append("plus" if key == "plus" else "")
644
+
645
+ return params, descriptions
646
+
647
+ if self.unet_loras:
648
+ params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio)
649
+ all_params.extend(params)
650
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
651
+
652
+ return all_params, lr_descriptions
653
+
654
+ def enable_gradient_checkpointing(self):
655
+ # not supported
656
+ pass
657
+
658
+ def prepare_grad_etc(self, unet):
659
+ self.requires_grad_(True)
660
+
661
+ def on_epoch_start(self, unet):
662
+ self.train()
663
+
664
+ def on_step_start(self):
665
+ pass
666
+
667
+ def get_trainable_params(self):
668
+ return self.parameters()
669
+
670
+ def save_weights(self, file, dtype, metadata):
671
+ if metadata is not None and len(metadata) == 0:
672
+ metadata = None
673
+
674
+ state_dict = self.state_dict()
675
+
676
+ if dtype is not None:
677
+ for key in list(state_dict.keys()):
678
+ v = state_dict[key]
679
+ v = v.detach().clone().to("cpu").to(dtype)
680
+ state_dict[key] = v
681
+
682
+ if os.path.splitext(file)[1] == ".safetensors":
683
+ from safetensors.torch import save_file
684
+ from utils import model_utils
685
+
686
+ # Precalculate model hashes to save time on indexing
687
+ if metadata is None:
688
+ metadata = {}
689
+ model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata)
690
+ metadata["sshs_model_hash"] = model_hash
691
+ metadata["sshs_legacy_hash"] = legacy_hash
692
+
693
+ save_file(state_dict, file, metadata)
694
+ else:
695
+ torch.save(state_dict, file)
696
+
697
+ def backup_weights(self):
698
+ # back up the original weights
699
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
700
+ for lora in loras:
701
+ org_module = lora.org_module_ref[0]
702
+ if not hasattr(org_module, "_lora_org_weight"):
703
+ sd = org_module.state_dict()
704
+ org_module._lora_org_weight = sd["weight"].detach().clone()
705
+ org_module._lora_restored = True
706
+
707
+ def restore_weights(self):
708
+ # restore the original weights
709
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
710
+ for lora in loras:
711
+ org_module = lora.org_module_ref[0]
712
+ if not org_module._lora_restored:
713
+ sd = org_module.state_dict()
714
+ sd["weight"] = org_module._lora_org_weight
715
+ org_module.load_state_dict(sd)
716
+ org_module._lora_restored = True
717
+
718
+ def pre_calculation(self):
719
+ # pre-calculate the merged weights
720
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
721
+ for lora in loras:
722
+ org_module = lora.org_module_ref[0]
723
+ sd = org_module.state_dict()
724
+
725
+ org_weight = sd["weight"]
726
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
727
+ sd["weight"] = org_weight + lora_weight
728
+ assert sd["weight"].shape == org_weight.shape
729
+ org_module.load_state_dict(sd)
730
+
731
+ org_module._lora_restored = False
732
+ lora.enabled = False
733
+
734
+ def apply_max_norm_regularization(self, max_norm_value, device):
735
+ downkeys = []
736
+ upkeys = []
737
+ alphakeys = []
738
+ norms = []
739
+ keys_scaled = 0
740
+
741
+ state_dict = self.state_dict()
742
+ for key in state_dict.keys():
743
+ if "lora_down" in key and "weight" in key:
744
+ downkeys.append(key)
745
+ upkeys.append(key.replace("lora_down", "lora_up"))
746
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
747
+
748
+ for i in range(len(downkeys)):
749
+ down = state_dict[downkeys[i]].to(device)
750
+ up = state_dict[upkeys[i]].to(device)
751
+ alpha = state_dict[alphakeys[i]].to(device)
752
+ dim = down.shape[0]
753
+ scale = alpha / dim
754
+
755
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
756
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
757
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
758
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
759
+ else:
760
+ updown = up @ down
761
+
762
+ updown *= scale
763
+
764
+ norm = updown.norm().clamp(min=max_norm_value / 2)
765
+ desired = torch.clamp(norm, max=max_norm_value)
766
+ ratio = desired.cpu() / norm.cpu()
767
+ sqrt_ratio = ratio**0.5
768
+ if ratio != 1:
769
+ keys_scaled += 1
770
+ state_dict[upkeys[i]] *= sqrt_ratio
771
+ state_dict[downkeys[i]] *= sqrt_ratio
772
+ scalednorm = updown.norm() * ratio
773
+ norms.append(scalednorm.item())
774
+
775
+ return keys_scaled, sum(norms) / len(norms), max(norms)
776
+
777
+
778
+ def create_network_from_weights_hunyuan_video(
779
+ multiplier: float,
780
+ weights_sd: Dict[str, torch.Tensor],
781
+ text_encoders: Optional[List[nn.Module]] = None,
782
+ unet: Optional[nn.Module] = None,
783
+ for_inference: bool = False,
784
+ **kwargs,
785
+ ) -> LoRANetwork:
786
+ return create_network_from_weights(
787
+ HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
788
+ )
789
+
790
+
791
+ # Create network from weights for inference; weights are not loaded here (because they may be merged instead)
792
+ def create_network_from_weights(
793
+ target_replace_modules: List[str],
794
+ multiplier: float,
795
+ weights_sd: Dict[str, torch.Tensor],
796
+ text_encoders: Optional[List[nn.Module]] = None,
797
+ unet: Optional[nn.Module] = None,
798
+ for_inference: bool = False,
799
+ **kwargs,
800
+ ) -> LoRANetwork:
801
+ # get dim/alpha mapping
802
+ modules_dim = {}
803
+ modules_alpha = {}
804
+ for key, value in weights_sd.items():
805
+ if "." not in key:
806
+ continue
807
+
808
+ lora_name = key.split(".")[0]
809
+ if "alpha" in key:
810
+ modules_alpha[lora_name] = value
811
+ elif "lora_down" in key:
812
+ dim = value.shape[0]
813
+ modules_dim[lora_name] = dim
814
+ # logger.info(lora_name, value.size(), dim)
815
+
816
+ module_class = LoRAInfModule if for_inference else LoRAModule
817
+
818
+ network = LoRANetwork(
819
+ target_replace_modules,
820
+ "lora_unet",
821
+ text_encoders,
822
+ unet,
823
+ multiplier=multiplier,
824
+ modules_dim=modules_dim,
825
+ modules_alpha=modules_alpha,
826
+ module_class=module_class,
827
+ )
828
+ return network
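A minimal training-loop sketch for networks/lora.py (hypothetical, not from the repository's scripts). The toy MMDoubleStreamBlock/ToyDiT classes stand in for the real HunyuanVideo DiT; only the class name needs to match HUNYUAN_TARGET_REPLACE_MODULES so the wrapper can find Linear children. Everything else uses the functions defined in this file:

import torch
import torch.nn as nn
from networks import lora

class MMDoubleStreamBlock(nn.Module):  # name matches HUNYUAN_TARGET_REPLACE_MODULES
    def __init__(self):
        super().__init__()
        self.img_attn_qkv = nn.Linear(64, 192)

class ToyDiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([MMDoubleStreamBlock() for _ in range(2)])

    def forward(self, x):
        for block in self.blocks:
            x = block.img_attn_qkv(x)[..., :64]  # keep shapes simple for the demo
        return x

dit = ToyDiT()
network = lora.create_network_hunyuan_video(
    multiplier=1.0, network_dim=8, network_alpha=4.0,
    vae=None, text_encoders=None, unet=dit,
)
network.apply_to(None, dit, apply_text_encoder=False, apply_unet=True)

params, _ = network.prepare_optimizer_params(unet_lr=1e-4)
optimizer = torch.optim.AdamW(params)

loss = dit(torch.randn(2, 64)).pow(2).mean()
loss.backward()
optimizer.step()

network.save_weights("toy_lora.safetensors", torch.bfloat16, {"ss_network_dim": "8"})

For inference, create_network_from_weights_hunyuan_video with for_inference=True swaps in LoRAInfModule, so the saved weights can either be applied via apply_to or merged into the base weights with merge_to.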
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ accelerate==1.2.1
2
+ av==14.0.1
3
+ bitsandbytes==0.45.0
4
+ diffusers==0.32.1
5
+ einops==0.7.0
6
+ huggingface-hub==0.26.5
7
+ opencv-python==4.10.0.84
8
+ pillow==10.2.0
9
+ safetensors==0.4.5
10
+ toml==0.10.2
11
+ tqdm==4.67.1
12
+ transformers==4.46.3
13
+ voluptuous==0.15.2
14
+
15
+ # optional dependencies
16
+ # ascii-magic==2.3.0
17
+ # matplotlib==3.10.0
18
+ # tensorboard
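An optional sanity check (hypothetical helper, not part of the repository) that the pinned versions above are installed; note that torch is not pinned here and is typically installed separately to match the local CUDA build:

from importlib.metadata import PackageNotFoundError, version

with open("requirements.txt") as f:
    for line in f:
        requirement = line.split("#")[0].strip()
        if not requirement:
            continue
        name, _, pinned = requirement.partition("==")
        try:
            print(f"{name}: pinned {pinned}, installed {version(name)}")
        except PackageNotFoundError:
            print(f"{name}: NOT INSTALLED")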
utils/__init__.py ADDED
File without changes
utils/huggingface_utils.py ADDED
@@ -0,0 +1,89 @@
1
+ import threading
2
+ from typing import Union, BinaryIO
3
+ from huggingface_hub import HfApi
4
+ from pathlib import Path
5
+ import argparse
6
+ import os
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
+
13
+ def fire_in_thread(f, *args, **kwargs):
14
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
15
+
16
+
17
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
18
+ api = HfApi(
19
+ token=token,
20
+ )
21
+ try:
22
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
23
+ return True
24
+ except:
25
+ return False
26
+
27
+
28
+ def upload(
29
+ args: argparse.Namespace,
30
+ src: Union[str, Path, bytes, BinaryIO],
31
+ dest_suffix: str = "",
32
+ force_sync_upload: bool = False,
33
+ ):
34
+ repo_id = args.huggingface_repo_id
35
+ repo_type = args.huggingface_repo_type
36
+ token = args.huggingface_token
37
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
38
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
39
+ api = HfApi(token=token)
40
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
41
+ try:
42
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
43
+ except Exception as e: # RepositoryNotFoundError or something else
44
+ logger.error("===========================================")
45
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
46
+ logger.error("===========================================")
47
+
48
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
49
+
50
+ def uploader():
51
+ try:
52
+ if is_folder:
53
+ api.upload_folder(
54
+ repo_id=repo_id,
55
+ repo_type=repo_type,
56
+ folder_path=src,
57
+ path_in_repo=path_in_repo,
58
+ )
59
+ else:
60
+ api.upload_file(
61
+ repo_id=repo_id,
62
+ repo_type=repo_type,
63
+ path_or_fileobj=src,
64
+ path_in_repo=path_in_repo,
65
+ )
66
+ except Exception as e: # RuntimeError or something else
67
+ logger.error("===========================================")
68
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
69
+ logger.error("===========================================")
70
+
71
+ if args.async_upload and not force_sync_upload:
72
+ fire_in_thread(uploader)
73
+ else:
74
+ uploader()
75
+
76
+
77
+ def list_dir(
78
+ repo_id: str,
79
+ subfolder: str,
80
+ repo_type: str,
81
+ revision: str = "main",
82
+ token: str = None,
83
+ ):
84
+ api = HfApi(
85
+ token=token,
86
+ )
87
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
88
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
89
+ return file_list
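A hedged sketch of calling upload() outside the training scripts; the attribute names mirror exactly what upload() reads from args, while the repo id, folder path, and token are placeholders you would replace:

import argparse
from utils import huggingface_utils

args = argparse.Namespace(
    huggingface_repo_id="your-name/your-lora-repo",   # placeholder
    huggingface_repo_type="model",
    huggingface_token=None,                           # or a write-access token string
    huggingface_path_in_repo="lora",
    huggingface_repo_visibility=None,                 # anything but "public" keeps the repo private
    async_upload=False,
)

# uploads a folder (or a single file) synchronously; errors are logged, not raised
huggingface_utils.upload(args, "output/my_lora_dir", dest_suffix="", force_sync_upload=True)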
utils/model_utils.py ADDED
@@ -0,0 +1,151 @@
1
+ import hashlib
2
+ from io import BytesIO
3
+ from typing import Optional
4
+
5
+ import safetensors.torch
6
+ import torch
7
+
8
+
9
+ def model_hash(filename):
10
+ """Old model hash used by stable-diffusion-webui"""
11
+ try:
12
+ with open(filename, "rb") as file:
13
+ m = hashlib.sha256()
14
+
15
+ file.seek(0x100000)
16
+ m.update(file.read(0x10000))
17
+ return m.hexdigest()[0:8]
18
+ except FileNotFoundError:
19
+ return "NOFILE"
20
+ except IsADirectoryError: # Linux?
21
+ return "IsADirectory"
22
+ except PermissionError: # Windows
23
+ return "IsADirectory"
24
+
25
+
26
+ def calculate_sha256(filename):
27
+ """New model hash used by stable-diffusion-webui"""
28
+ try:
29
+ hash_sha256 = hashlib.sha256()
30
+ blksize = 1024 * 1024
31
+
32
+ with open(filename, "rb") as f:
33
+ for chunk in iter(lambda: f.read(blksize), b""):
34
+ hash_sha256.update(chunk)
35
+
36
+ return hash_sha256.hexdigest()
37
+ except FileNotFoundError:
38
+ return "NOFILE"
39
+ except IsADirectoryError: # Linux?
40
+ return "IsADirectory"
41
+ except PermissionError: # Windows
42
+ return "IsADirectory"
43
+
44
+
45
+ def addnet_hash_legacy(b):
46
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
47
+ m = hashlib.sha256()
48
+
49
+ b.seek(0x100000)
50
+ m.update(b.read(0x10000))
51
+ return m.hexdigest()[0:8]
52
+
53
+
54
+ def addnet_hash_safetensors(b):
55
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
56
+ hash_sha256 = hashlib.sha256()
57
+ blksize = 1024 * 1024
58
+
59
+ b.seek(0)
60
+ header = b.read(8)
61
+ n = int.from_bytes(header, "little")
62
+
63
+ offset = n + 8
64
+ b.seek(offset)
65
+ for chunk in iter(lambda: b.read(blksize), b""):
66
+ hash_sha256.update(chunk)
67
+
68
+ return hash_sha256.hexdigest()
69
+
70
+
71
+ def precalculate_safetensors_hashes(tensors, metadata):
72
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
73
+ save time on indexing the model later."""
74
+
75
+ # Because writing user metadata to the file can change the result of
76
+ # sd_models.model_hash(), only retain the training metadata for purposes of
77
+ # calculating the hash, as they are meant to be immutable
78
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
79
+
80
+ bytes = safetensors.torch.save(tensors, metadata)
81
+ b = BytesIO(bytes)
82
+
83
+ model_hash = addnet_hash_safetensors(b)
84
+ legacy_hash = addnet_hash_legacy(b)
85
+ return model_hash, legacy_hash
86
+
87
+
88
+ def dtype_to_str(dtype: torch.dtype) -> str:
89
+ # get name of the dtype
90
+ dtype_name = str(dtype).split(".")[-1]
91
+ return dtype_name
92
+
93
+
94
+ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
95
+ """
96
+ Convert a string to a torch.dtype
97
+
98
+ Args:
99
+ s: string representation of the dtype
100
+ default_dtype: default dtype to return if s is None
101
+
102
+ Returns:
103
+ torch.dtype: the corresponding torch.dtype
104
+
105
+ Raises:
106
+ ValueError: if the dtype is not supported
107
+
108
+ Examples:
109
+ >>> str_to_dtype("float32")
110
+ torch.float32
111
+ >>> str_to_dtype("fp32")
112
+ torch.float32
113
+ >>> str_to_dtype("float16")
114
+ torch.float16
115
+ >>> str_to_dtype("fp16")
116
+ torch.float16
117
+ >>> str_to_dtype("bfloat16")
118
+ torch.bfloat16
119
+ >>> str_to_dtype("bf16")
120
+ torch.bfloat16
121
+ >>> str_to_dtype("fp8")
122
+ torch.float8_e4m3fn
123
+ >>> str_to_dtype("fp8_e4m3fn")
124
+ torch.float8_e4m3fn
125
+ >>> str_to_dtype("fp8_e4m3fnuz")
126
+ torch.float8_e4m3fnuz
127
+ >>> str_to_dtype("fp8_e5m2")
128
+ torch.float8_e5m2
129
+ >>> str_to_dtype("fp8_e5m2fnuz")
130
+ torch.float8_e5m2fnuz
131
+ """
132
+ if s is None:
133
+ return default_dtype
134
+ if s in ["bf16", "bfloat16"]:
135
+ return torch.bfloat16
136
+ elif s in ["fp16", "float16"]:
137
+ return torch.float16
138
+ elif s in ["fp32", "float32", "float"]:
139
+ return torch.float32
140
+ elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
141
+ return torch.float8_e4m3fn
142
+ elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
143
+ return torch.float8_e4m3fnuz
144
+ elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
145
+ return torch.float8_e5m2
146
+ elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
147
+ return torch.float8_e5m2fnuz
148
+ elif s in ["fp8", "float8"]:
149
+ return torch.float8_e4m3fn # default fp8
150
+ else:
151
+ raise ValueError(f"Unsupported dtype: {s}")
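A quick round-trip with the dtype helpers defined above:

import torch
from utils import model_utils

dt = model_utils.str_to_dtype("bf16")
print(dt)                                              # torch.bfloat16
print(model_utils.dtype_to_str(dt))                    # "bfloat16"
print(model_utils.str_to_dtype(None, torch.float16))   # falls back to the default
print(model_utils.str_to_dtype("fp8"))                 # torch.float8_e4m3fn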
utils/safetensors_utils.py ADDED
@@ -0,0 +1,191 @@
1
+ import torch
2
+ import json
3
+ import struct
4
+ from typing import Dict, Any, Union, Optional
5
+
6
+ from safetensors.torch import load_file
7
+
8
+
9
+ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
10
+ """
11
+ memory efficient save file
12
+ """
13
+
14
+ _TYPES = {
15
+ torch.float64: "F64",
16
+ torch.float32: "F32",
17
+ torch.float16: "F16",
18
+ torch.bfloat16: "BF16",
19
+ torch.int64: "I64",
20
+ torch.int32: "I32",
21
+ torch.int16: "I16",
22
+ torch.int8: "I8",
23
+ torch.uint8: "U8",
24
+ torch.bool: "BOOL",
25
+ getattr(torch, "float8_e5m2", None): "F8_E5M2",
26
+ getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
27
+ }
28
+ _ALIGN = 256
29
+
30
+ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
31
+ validated = {}
32
+ for key, value in metadata.items():
33
+ if not isinstance(key, str):
34
+ raise ValueError(f"Metadata key must be a string, got {type(key)}")
35
+ if not isinstance(value, str):
36
+ print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
37
+ validated[key] = str(value)
38
+ else:
39
+ validated[key] = value
40
+ return validated
41
+
42
+ # print(f"Using memory efficient save file: {filename}")
43
+
44
+ header = {}
45
+ offset = 0
46
+ if metadata:
47
+ header["__metadata__"] = validate_metadata(metadata)
48
+ for k, v in tensors.items():
49
+ if v.numel() == 0: # empty tensor
50
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
51
+ else:
52
+ size = v.numel() * v.element_size()
53
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
54
+ offset += size
55
+
56
+ hjson = json.dumps(header).encode("utf-8")
57
+ hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
58
+
59
+ with open(filename, "wb") as f:
60
+ f.write(struct.pack("<Q", len(hjson)))
61
+ f.write(hjson)
62
+
63
+ for k, v in tensors.items():
64
+ if v.numel() == 0:
65
+ continue
66
+ if v.is_cuda:
67
+ # Direct GPU to disk save
68
+ with torch.cuda.device(v.device):
69
+ if v.dim() == 0: # if scalar, need to add a dimension to work with view
70
+ v = v.unsqueeze(0)
71
+ tensor_bytes = v.contiguous().view(torch.uint8)
72
+ tensor_bytes.cpu().numpy().tofile(f)
73
+ else:
74
+ # CPU tensor save
75
+ if v.dim() == 0: # if scalar, need to add a dimension to work with view
76
+ v = v.unsqueeze(0)
77
+ v.contiguous().view(torch.uint8).numpy().tofile(f)
78
+
79
+
80
+ class MemoryEfficientSafeOpen:
81
+ # memory-efficient safetensors reader (header metadata is available via metadata())
82
+ def __init__(self, filename):
83
+ self.filename = filename
84
+ self.file = open(filename, "rb")
85
+ self.header, self.header_size = self._read_header()
86
+
87
+ def __enter__(self):
88
+ return self
89
+
90
+ def __exit__(self, exc_type, exc_val, exc_tb):
91
+ self.file.close()
92
+
93
+ def keys(self):
94
+ return [k for k in self.header.keys() if k != "__metadata__"]
95
+
96
+ def metadata(self) -> Dict[str, str]:
97
+ return self.header.get("__metadata__", {})
98
+
99
+ def get_tensor(self, key):
100
+ if key not in self.header:
101
+ raise KeyError(f"Tensor '{key}' not found in the file")
102
+
103
+ metadata = self.header[key]
104
+ offset_start, offset_end = metadata["data_offsets"]
105
+
106
+ if offset_start == offset_end:
107
+ tensor_bytes = None
108
+ else:
109
+ # adjust offset by header size
110
+ self.file.seek(self.header_size + 8 + offset_start)
111
+ tensor_bytes = self.file.read(offset_end - offset_start)
112
+
113
+ return self._deserialize_tensor(tensor_bytes, metadata)
114
+
115
+ def _read_header(self):
116
+ header_size = struct.unpack("<Q", self.file.read(8))[0]
117
+ header_json = self.file.read(header_size).decode("utf-8")
118
+ return json.loads(header_json), header_size
119
+
120
+ def _deserialize_tensor(self, tensor_bytes, metadata):
121
+ dtype = self._get_torch_dtype(metadata["dtype"])
122
+ shape = metadata["shape"]
123
+
124
+ if tensor_bytes is None:
125
+ byte_tensor = torch.empty(0, dtype=torch.uint8)
126
+ else:
127
+ tensor_bytes = bytearray(tensor_bytes) # make it writable
128
+ byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
129
+
130
+ # process float8 types
131
+ if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
132
+ return self._convert_float8(byte_tensor, metadata["dtype"], shape)
133
+
134
+ # convert to the target dtype and reshape
135
+ return byte_tensor.view(dtype).reshape(shape)
136
+
137
+ @staticmethod
138
+ def _get_torch_dtype(dtype_str):
139
+ dtype_map = {
140
+ "F64": torch.float64,
141
+ "F32": torch.float32,
142
+ "F16": torch.float16,
143
+ "BF16": torch.bfloat16,
144
+ "I64": torch.int64,
145
+ "I32": torch.int32,
146
+ "I16": torch.int16,
147
+ "I8": torch.int8,
148
+ "U8": torch.uint8,
149
+ "BOOL": torch.bool,
150
+ }
151
+ # add float8 types if available
152
+ if hasattr(torch, "float8_e5m2"):
153
+ dtype_map["F8_E5M2"] = torch.float8_e5m2
154
+ if hasattr(torch, "float8_e4m3fn"):
155
+ dtype_map["F8_E4M3"] = torch.float8_e4m3fn
156
+ return dtype_map.get(dtype_str)
157
+
158
+ @staticmethod
159
+ def _convert_float8(byte_tensor, dtype_str, shape):
160
+ if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
161
+ return byte_tensor.view(torch.float8_e5m2).reshape(shape)
162
+ elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
163
+ return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
164
+ else:
165
+ # # convert to float16 if float8 is not supported
166
+ # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
167
+ # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
168
+ raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
169
+
170
+
171
+ def load_safetensors(
172
+ path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
173
+ ) -> dict[str, torch.Tensor]:
174
+ if disable_mmap:
175
+ # return safetensors.torch.load(open(path, "rb").read())
176
+ # use experimental loader
177
+ # logger.info(f"Loading without mmap (experimental)")
178
+ state_dict = {}
179
+ with MemoryEfficientSafeOpen(path) as f:
180
+ for key in f.keys():
181
+ state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
182
+ return state_dict
183
+ else:
184
+ try:
185
+ state_dict = load_file(path, device=device)
186
+ except:
187
+ state_dict = load_file(path) # prevent device invalid Error
188
+ if dtype is not None:
189
+ for key in state_dict.keys():
190
+ state_dict[key] = state_dict[key].to(dtype=dtype)
191
+ return state_dict
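A minimal round-trip with the helpers above (the file name is arbitrary):

import torch
from utils.safetensors_utils import MemoryEfficientSafeOpen, load_safetensors, mem_eff_save_file

tensors = {"weight": torch.randn(4, 4, dtype=torch.float16)}
mem_eff_save_file(tensors, "example.safetensors", metadata={"purpose": "demo"})

with MemoryEfficientSafeOpen("example.safetensors") as f:
    print(f.keys(), f.metadata())
    weight = f.get_tensor("weight")          # read a single tensor without mmap

state_dict = load_safetensors("example.safetensors", device="cpu", disable_mmap=True, dtype=torch.float32)
print(state_dict["weight"].dtype)            # torch.float32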
utils/sai_model_spec.py ADDED
@@ -0,0 +1,263 @@
1
+ # based on https://github.com/Stability-AI/ModelSpec
2
+ import datetime
3
+ import hashlib
4
+ from io import BytesIO
5
+ import os
6
+ from typing import List, Optional, Tuple, Union
7
+ import safetensors
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.INFO)
12
+
13
+
14
+ r"""
15
+ # Metadata Example
16
+ metadata = {
17
+ # === Must ===
18
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
19
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
20
+ "modelspec.implementation": "sgm",
21
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
22
+ # === Should ===
23
+ "modelspec.author": "Example Corp", # Your name or company name
24
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
25
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
26
+ # === Can ===
27
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
28
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
29
+ }
30
+ """
31
+
32
+ BASE_METADATA = {
33
+ # === Must ===
34
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
35
+ "modelspec.architecture": None,
36
+ "modelspec.implementation": None,
37
+ "modelspec.title": None,
38
+ "modelspec.resolution": None,
39
+ # === Should ===
40
+ "modelspec.description": None,
41
+ "modelspec.author": None,
42
+ "modelspec.date": None,
43
+ # === Can ===
44
+ "modelspec.license": None,
45
+ "modelspec.tags": None,
46
+ "modelspec.merged_from": None,
47
+ "modelspec.prediction_type": None,
48
+ "modelspec.timestep_range": None,
49
+ "modelspec.encoder_layer": None,
50
+ }
51
+
52
+ # define separately only the keys that are referenced elsewhere
53
+ MODELSPEC_TITLE = "modelspec.title"
54
+
55
+ ARCH_HUNYUAN_VIDEO = "hunyuan-video"
56
+
57
+ ADAPTER_LORA = "lora"
58
+
59
+ IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"
60
+
61
+ PRED_TYPE_EPSILON = "epsilon"
62
+ # PRED_TYPE_V = "v"
63
+
64
+
65
+ def load_bytes_in_safetensors(tensors):
66
+ bytes = safetensors.torch.save(tensors)
67
+ b = BytesIO(bytes)
68
+
69
+ b.seek(0)
70
+ header = b.read(8)
71
+ n = int.from_bytes(header, "little")
72
+
73
+ offset = n + 8
74
+ b.seek(offset)
75
+
76
+ return b.read()
77
+
78
+
79
+ def precalculate_safetensors_hashes(state_dict):
80
+ # calculate each tensor one by one to reduce memory usage
81
+ hash_sha256 = hashlib.sha256()
82
+ for tensor in state_dict.values():
83
+ single_tensor_sd = {"tensor": tensor}
84
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
85
+ hash_sha256.update(bytes_for_tensor)
86
+
87
+ return f"0x{hash_sha256.hexdigest()}"
88
+
89
+
90
+ def update_hash_sha256(metadata: dict, state_dict: dict):
91
+ raise NotImplementedError
92
+
93
+
94
+ def build_metadata(
95
+ state_dict: Optional[dict],
96
+ timestamp: float,
97
+ title: Optional[str] = None,
98
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
99
+ author: Optional[str] = None,
100
+ description: Optional[str] = None,
101
+ license: Optional[str] = None,
102
+ tags: Optional[str] = None,
103
+ merged_from: Optional[str] = None,
104
+ timesteps: Optional[Tuple[int, int]] = None,
105
+ ):
106
+ metadata = {}
107
+ metadata.update(BASE_METADATA)
108
+
109
+ # TODO implement if we can calculate hash without loading all tensors
110
+ # if state_dict is not None:
111
+ # hash = precalculate_safetensors_hashes(state_dict)
112
+ # metadata["modelspec.hash_sha256"] = hash
113
+
114
+ arch = ARCH_HUNYUAN_VIDEO
115
+ arch += f"/{ADAPTER_LORA}"
116
+ metadata["modelspec.architecture"] = arch
117
+
118
+ impl = IMPL_HUNYUAN_VIDEO
119
+ metadata["modelspec.implementation"] = impl
120
+
121
+ if title is None:
122
+ title = "LoRA"
123
+ title += f"@{timestamp}"
124
+ metadata[MODELSPEC_TITLE] = title
125
+
126
+ if author is not None:
127
+ metadata["modelspec.author"] = author
128
+ else:
129
+ del metadata["modelspec.author"]
130
+
131
+ if description is not None:
132
+ metadata["modelspec.description"] = description
133
+ else:
134
+ del metadata["modelspec.description"]
135
+
136
+ if merged_from is not None:
137
+ metadata["modelspec.merged_from"] = merged_from
138
+ else:
139
+ del metadata["modelspec.merged_from"]
140
+
141
+ if license is not None:
142
+ metadata["modelspec.license"] = license
143
+ else:
144
+ del metadata["modelspec.license"]
145
+
146
+ if tags is not None:
147
+ metadata["modelspec.tags"] = tags
148
+ else:
149
+ del metadata["modelspec.tags"]
150
+
151
+ # remove microsecond from time
152
+ int_ts = int(timestamp)
153
+
154
+ # time to iso-8601 compliant date
155
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
156
+ metadata["modelspec.date"] = date
157
+
158
+ if reso is not None:
159
+ # comma separated to tuple
160
+ if isinstance(reso, str):
161
+ reso = tuple(map(int, reso.split(",")))
162
+ if len(reso) == 1:
163
+ reso = (reso[0], reso[0])
164
+ else:
165
+ # resolution is defined in dataset, so use default
166
+ reso = (1280, 720)
167
+ if isinstance(reso, int):
168
+ reso = (reso, reso)
169
+
170
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
171
+
172
+ # metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
173
+ del metadata["modelspec.prediction_type"]
174
+
175
+ if timesteps is not None:
176
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
177
+ timesteps = (timesteps, timesteps)
178
+ if len(timesteps) == 1:
179
+ timesteps = (timesteps[0], timesteps[0])
180
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
181
+ else:
182
+ del metadata["modelspec.timestep_range"]
183
+
184
+ # if clip_skip is not None:
185
+ # metadata["modelspec.encoder_layer"] = f"{clip_skip}"
186
+ # else:
187
+ del metadata["modelspec.encoder_layer"]
188
+
189
+ # # assert all values are filled
190
+ # assert all([v is not None for v in metadata.values()]), metadata
191
+ if not all([v is not None for v in metadata.values()]):
192
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
193
+
194
+ return metadata
195
+
196
+
197
+ # region utils
198
+
199
+
200
+ def get_title(metadata: dict) -> Optional[str]:
201
+ return metadata.get(MODELSPEC_TITLE, None)
202
+
203
+
204
+ def load_metadata_from_safetensors(model: str) -> dict:
205
+ if not model.endswith(".safetensors"):
206
+ return {}
207
+
208
+ with safetensors.safe_open(model, framework="pt") as f:
209
+ metadata = f.metadata()
210
+ if metadata is None:
211
+ metadata = {}
212
+ return metadata
213
+
214
+
215
+ def build_merged_from(models: List[str]) -> str:
216
+ def get_title(model: str):
217
+ metadata = load_metadata_from_safetensors(model)
218
+ title = metadata.get(MODELSPEC_TITLE, None)
219
+ if title is None:
220
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
221
+ return title
222
+
223
+ titles = [get_title(model) for model in models]
224
+ return ", ".join(titles)
225
+
226
+
227
+ # endregion
228
+
229
+
230
+ r"""
231
+ if __name__ == "__main__":
232
+ import argparse
233
+ import torch
234
+ from safetensors.torch import load_file
235
+ from library import train_util
236
+
237
+ parser = argparse.ArgumentParser()
238
+ parser.add_argument("--ckpt", type=str, required=True)
239
+ args = parser.parse_args()
240
+
241
+ print(f"Loading {args.ckpt}")
242
+ state_dict = load_file(args.ckpt)
243
+
244
+ print(f"Calculating metadata")
245
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
246
+ print(metadata)
247
+ del state_dict
248
+
249
+ # by reference implementation
250
+ with open(args.ckpt, mode="rb") as file_data:
251
+ file_hash = hashlib.sha256()
252
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
253
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
254
+ content = (
255
+ file_data.read()
256
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
257
+ file_hash.update(content)
258
+ # ===== Update the hash for modelspec =====
259
+ by_ref = f"0x{file_hash.hexdigest()}"
260
+ print(by_ref)
261
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
262
+
263
+ """
utils/train_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import shutil
5
+
6
+ import accelerate
7
+ import torch
8
+
9
+ from utils import huggingface_utils
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ # checkpoint file name templates
16
+ EPOCH_STATE_NAME = "{}-{:06d}-state"
17
+ EPOCH_FILE_NAME = "{}-{:06d}"
18
+ EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
19
+ LAST_STATE_NAME = "{}-state"
20
+ STEP_STATE_NAME = "{}-step{:08d}-state"
21
+ STEP_FILE_NAME = "{}-step{:08d}"
22
+ STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
23
+
24
+
25
+ def get_sanitized_config_or_none(args: argparse.Namespace):
26
+ # if `--log_config` is enabled, return args for logging. if not, return None.
27
+ # when `--log_config is enabled, filter out sensitive values from args
28
+ # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe
29
+
30
+ if not args.log_config:
31
+ return None
32
+
33
+ sensitive_args = ["wandb_api_key", "huggingface_token"]
34
+ sensitive_path_args = [
35
+ "dit",
36
+ "vae",
37
+ "text_encoder1",
38
+ "text_encoder2",
39
+ "base_weights",
40
+ "network_weights",
41
+ "output_dir",
42
+ "logging_dir",
43
+ ]
44
+ filtered_args = {}
45
+ for k, v in vars(args).items():
46
+ # filter out sensitive values and convert to string if necessary
47
+ if k not in sensitive_args + sensitive_path_args:
48
+ # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
49
+ if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
50
+ filtered_args[k] = v
51
+ # accelerate does not support lists
52
+ elif isinstance(v, list):
53
+ filtered_args[k] = f"{v}"
54
+ # accelerate does not support objects
55
+ elif isinstance(v, object):
56
+ filtered_args[k] = f"{v}"
57
+
58
+ return filtered_args
59
+
60
+
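The returned dict is restricted to `bool`/`str`/`float`/`int`/`None` values so it can be handed to Accelerate's experiment trackers. A rough sketch of such a call site, assuming `accelerator` and `args` already exist in the training script; the project name is a placeholder:

```python
from utils.train_utils import get_sanitized_config_or_none

config = get_sanitized_config_or_none(args)  # None unless --log_config was passed
if accelerator.is_main_process:
    accelerator.init_trackers("network_train", config=config)  # placeholder project name
```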
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: list[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            while len(self.loss_list) <= step:
+                self.loss_list.append(0.0)
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    @property
+    def moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
+
+
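A small usage sketch of `LossRecorder` (values are made up): during the first epoch losses are appended, while later epochs overwrite the slot for the same step, so `moving_average` tracks roughly a one-epoch window:

```python
from utils.train_utils import LossRecorder

recorder = LossRecorder()
recorder.add(epoch=0, step=0, loss=1.0)  # first epoch: values are appended
recorder.add(epoch=0, step=1, loss=0.8)
print(recorder.moving_average)           # (1.0 + 0.8) / 2 = 0.9

recorder.add(epoch=1, step=0, loss=0.6)  # later epochs: the step-0 slot is replaced
print(recorder.moving_average)           # (0.6 + 0.8) / 2 = 0.7
```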
+def get_epoch_ckpt_name(model_name, epoch_no: int):
+    return EPOCH_FILE_NAME.format(model_name, epoch_no) + ".safetensors"
+
+
+def get_step_ckpt_name(model_name, step_no: int):
+    return STEP_FILE_NAME.format(model_name, step_no) + ".safetensors"
+
+
+def get_last_ckpt_name(model_name):
+    return model_name + ".safetensors"
+
+
+def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
+    if args.save_last_n_epochs is None:
+        return None
+
+    remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
+    if remove_epoch_no < 0:
+        return None
+    return remove_epoch_no
+
+
+def get_remove_step_no(args: argparse.Namespace, step_no: int):
+    if args.save_last_n_steps is None:
+        return None
+
+    # calculate the step number to remove from save_last_n_steps and save_every_n_steps
+    # e.g. if save_every_n_steps=10 and save_last_n_steps=30, then at step 50 the last 30 steps are kept and step 10 is removed
+    remove_step_no = step_no - args.save_last_n_steps - 1
+    remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+    if remove_step_no < 0:
+        return None
+    return remove_step_no
+
+
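The example in the comment above can be checked directly; `argparse.Namespace` stands in for the parsed training arguments, and the import assumes the repository root is on `PYTHONPATH`:

```python
import argparse

from utils.train_utils import get_remove_step_no

args = argparse.Namespace(save_every_n_steps=10, save_last_n_steps=30)
print(get_remove_step_no(args, 50))  # (50 - 30 - 1) = 19 -> 19 - 19 % 10 = 10
print(get_remove_step_no(args, 25))  # (25 - 30 - 1) = -6 -> rounds below zero -> None
```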
+def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator: accelerate.Accelerator, epoch_no: int):
+    model_name = args.output_name
+
+    logger.info("")
+    logger.info(f"saving state at epoch {epoch_no}")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        logger.info("uploading state to huggingface.")
+        huggingface_utils.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
+
+    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+    if last_n_epochs is not None:
+        remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
+        state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
+        if os.path.exists(state_dir_old):
+            logger.info(f"removing old state: {state_dir_old}")
+            shutil.rmtree(state_dir_old)
+
+
+def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator: accelerate.Accelerator, step_no: int):
+    model_name = args.output_name
+
+    logger.info("")
+    logger.info(f"saving state at step {step_no}")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        logger.info("uploading state to huggingface.")
+        huggingface_utils.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
+
+    last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
+    if last_n_steps is not None:
+        # from the step_no that is last_n_steps back, compute the step_no that is a multiple of save_every_n_steps and remove its state
+        remove_step_no = step_no - last_n_steps - 1
+        remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
+
+        if remove_step_no > 0:
+            state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
+            if os.path.exists(state_dir_old):
+                logger.info(f"removing old state: {state_dir_old}")
+                shutil.rmtree(state_dir_old)
+
+
+def save_state_on_train_end(args: argparse.Namespace, accelerator: accelerate.Accelerator):
+    model_name = args.output_name
+
+    logger.info("")
+    logger.info("saving last state.")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
+    accelerator.save_state(state_dir)
+
+    if args.save_state_to_huggingface:
+        logger.info("uploading last state to huggingface.")
+        huggingface_utils.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000001.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c83468be7bc357b777fa900e3ed4fd4452a142ba46dde32074a7d7e15ba9695c
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000002.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a56b9f52c800b05f93a1437475942deae891d04dab83ae2ad34e179606fbdce
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000003.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24b72675e75de1ef2e06d9a7abfb4bfdd17c70ce26901beccd5085dfffe664d8
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f0aa3e1a2610f9d1e80cffe36a022a96d03c6890a47de2cd48e7fef5917546e
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000005.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:536d31441ec18c50c3db6ab089f53bf14b62ce9ad9256e1fbd64b9e0dd58a8e6
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000006.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6f51d038ea3d3bbe7e30b780c7e69ebf932029c57daf45f5f2b7802b7617464
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000007.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9fcb428b85762e86071a631e4e2d6b438ac8d157e616db50c05a7ec2efbc67dc
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000008.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57a00a9e3a7f5ebd7de3fe1ae947f73e55c5a0c06f5adb25226d8e606d20992b
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000009.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db9255ec51a3015974ea04de5d309682e38e2130683d9d4f2283aa6cda120021
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000010.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07ff9643355fddf8f583dc94d73e412223630d24a338c59bfc7d3dc9c8107eca
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000011.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:272ef2da97d9e142084b5dbef41d419a65c8080c316a46ff9870caf793236fdd
+size 322557560
zhongli_ningguang_couple_im_lora_dir/zhongli_ningguang_couple_im_lora-000012.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0308166fce29bed867ad7c31ac796d5b21264c8c0ba0638a468dea6cec2dfcd
+size 322557560