bubbliiiing committed on
Commit f62c8b9
1 Parent(s): ab9a89a
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. app.py +31 -8
  2. config/easyanimate_image_magvit_v2.yaml +0 -8
  3. config/easyanimate_image_normal_v1.yaml +0 -8
  4. config/easyanimate_image_slicevae_v3.yaml +0 -9
  5. config/easyanimate_video_casual_motion_module_v1.yaml +0 -27
  6. config/easyanimate_video_long_sequence_v1.yaml +0 -14
  7. config/{easyanimate_video_motion_module_v1.yaml → easyanimate_video_v1_motion_module.yaml} +5 -7
  8. config/{easyanimate_video_slicevae_motion_module_v3.yaml → easyanimate_video_v2_magvit_motion_module.yaml} +11 -9
  9. config/{easyanimate_video_magvit_motion_module_v2.yaml → easyanimate_video_v3_slicevae_motion_module.yaml} +24 -11
  10. config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml +20 -0
  11. config/easyanimate_video_v5_magvit_multi_text_encoder.yaml +19 -0
  12. config/zero_stage2_config.json +16 -0
  13. easyanimate/api/api.py +55 -9
  14. easyanimate/api/post_infer.py +0 -1
  15. easyanimate/data/dataset_image_video.py +311 -22
  16. easyanimate/models/__init__.py +16 -0
  17. easyanimate/models/attention.py +437 -659
  18. easyanimate/models/autoencoder_magvit.py +520 -4
  19. easyanimate/models/embeddings.py +107 -0
  20. easyanimate/models/norm.py +55 -2
  21. easyanimate/models/patch.py +0 -9
  22. easyanimate/models/processor.py +312 -0
  23. easyanimate/models/resampler.py +146 -0
  24. easyanimate/models/transformer2d.py +23 -58
  25. easyanimate/models/transformer3d.py +762 -70
  26. easyanimate/pipeline/pipeline_easyanimate.py +29 -39
  27. easyanimate/pipeline/pipeline_easyanimate_inpaint.py +90 -138
  28. easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +925 -0
  29. easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py +996 -0
  30. easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +1334 -0
  31. easyanimate/ui/ui.py +0 -0
  32. easyanimate/utils/discrete_sampler.py +46 -0
  33. easyanimate/utils/fp8_optimization.py +28 -0
  34. easyanimate/utils/lora_utils.py +26 -20
  35. easyanimate/utils/utils.py +64 -20
  36. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml +64 -0
  37. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml +65 -0
  38. easyanimate/vae/ldm/data/dataset_callback.py +1 -0
  39. easyanimate/vae/ldm/data/dataset_image_video.py +7 -4
  40. easyanimate/vae/ldm/models/casual3dcnn.py +337 -0
  41. easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +326 -0
  42. easyanimate/vae/ldm/models/cogvideox_enc_dec.py +312 -0
  43. easyanimate/vae/ldm/models/{enc_dec_pytorch.py → enc_dec.py} +0 -0
  44. easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +48 -28
  45. easyanimate/vae/ldm/models/omnigen_enc_dec.py +296 -27
  46. easyanimate/vae/ldm/modules/ema.py +2 -1
  47. easyanimate/vae/ldm/modules/losses/contperceptual.py +2 -9
  48. easyanimate/vae/ldm/modules/vaemodules/common.py +106 -27
  49. easyanimate/vae/ldm/modules/vaemodules/upsamplers.py +4 -23
  50. easyanimate/video_caption/README.md +0 -90
app.py CHANGED
@@ -1,27 +1,50 @@
 import time
 
-from easyanimate.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
-from easyanimate.ui.ui import ui_modelscope, ui_eas, ui
+import torch
+
+from easyanimate.api.api import (infer_forward_api,
+                                 update_diffusion_transformer_api,
+                                 update_edition_api)
+from easyanimate.ui.ui import ui, ui_eas, ui_modelscope
 
 if __name__ == "__main__":
     # Choose the ui mode
     ui_mode = "eas"
+
+    # GPU memory mode, which can be chosen from ["model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"].
+    # "model_cpu_offload" means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+    #
+    # "model_cpu_offload_and_qfloat8" indicates that the entire model will be moved to the CPU after use,
+    # and the transformer model has been quantized to float8, which can save more GPU memory.
+    #
+    # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use,
+    # resulting in slower speeds but saving a large amount of GPU memory.
+    GPU_memory_mode = "model_cpu_offload_and_qfloat8"
+    # Use torch.float16 if the GPU does not support torch.bfloat16.
+    # Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16.
+    weight_dtype = torch.bfloat16
+
     # Server ip
     server_name = "0.0.0.0"
     server_port = 7860
 
     # Params below are used when ui_mode = "modelscope"
-    edition = "v3"
-    config_path = "config/easyanimate_video_slicevae_motion_module_v3.yaml"
-    model_name = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512"
+    edition = "v5"
+    # Config
+    config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml"
+    # Model path of the pretrained model
+    model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    # "Inpaint" or "Control"
+    model_type = "Inpaint"
+    # Save dir
     savedir_sample = "samples"
 
     if ui_mode == "modelscope":
-        demo, controller = ui_modelscope(edition, config_path, model_name, savedir_sample)
+        demo, controller = ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype)
     elif ui_mode == "eas":
         demo, controller = ui_eas(edition, config_path, model_name, savedir_sample)
     else:
-        demo, controller = ui()
+        demo, controller = ui(GPU_memory_mode, weight_dtype)
 
     # launch gradio
     app, _, _ = demo.queue(status_update_rate=1).launch(
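Editor's note: the diff adds a `GPU_memory_mode` switch but the code that applies it lives in `easyanimate/ui/ui.py` and `easyanimate/utils/fp8_optimization.py` (also changed in this commit) and is not shown here. The sketch below is only an assumed approximation of what the three modes typically mean for a diffusers-style pipeline; `pipeline` and its `transformer` attribute are placeholders, not repository code.

```python
# Hedged sketch: rough mapping of the GPU_memory_mode values onto a
# diffusers-style pipeline. Not the repository's actual implementation.
import torch

def apply_gpu_memory_mode(pipeline, gpu_memory_mode, weight_dtype=torch.bfloat16):
    if gpu_memory_mode == "sequential_cpu_offload":
        # Move each submodule to the GPU only while it runs: slowest, lowest memory.
        pipeline.enable_sequential_cpu_offload()
    elif gpu_memory_mode == "model_cpu_offload_and_qfloat8":
        # Store transformer weights in float8 to shrink resident memory. In the real
        # code this is paired with a wrapper that upcasts at compute time
        # (see easyanimate/utils/fp8_optimization.py); shown here only schematically.
        for _, param in pipeline.transformer.named_parameters():
            param.data = param.data.to(torch.float8_e4m3fn)
        pipeline.enable_model_cpu_offload()
    else:  # "model_cpu_offload"
        # Offload whole models back to the CPU between pipeline stages.
        pipeline.enable_model_cpu_offload()
    return pipeline
```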
config/easyanimate_image_magvit_v2.yaml DELETED
@@ -1,8 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: true
config/easyanimate_image_normal_v1.yaml DELETED
@@ -1,8 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/easyanimate_image_slicevae_v3.yaml DELETED
@@ -1,9 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: true
-  slice_compression_vae: true
config/easyanimate_video_casual_motion_module_v1.yaml DELETED
@@ -1,27 +0,0 @@
-transformer_additional_kwargs:
-  patch_3d: false
-  fake_3d: false
-  casual_3d: true
-  casual_3d_upsampler_index: [16, 20]
-  time_patch_size: 4
-  basic_block_type: "motionmodule"
-  time_position_encoding_before_transformer: false
-  motion_module_type: "VanillaGrid"
-
-  motion_module_kwargs:
-    num_attention_heads: 8
-    num_transformer_block: 1
-    attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
-    temporal_position_encoding: true
-    temporal_position_encoding_max_len: 4096
-    temporal_attention_dim_div: 1
-    block_size: 2
-
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/easyanimate_video_long_sequence_v1.yaml DELETED
@@ -1,14 +0,0 @@
-transformer_additional_kwargs:
-  patch_3d: false
-  fake_3d: false
-  basic_block_type: "selfattentiontemporal"
-  time_position_encoding_before_transformer: true
-
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/{easyanimate_video_motion_module_v1.yaml → easyanimate_video_v1_motion_module.yaml} RENAMED
@@ -1,4 +1,5 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
   basic_block_type: "motionmodule"
@@ -14,11 +15,8 @@ transformer_additional_kwargs:
     temporal_attention_dim_div: 1
     block_size: 2
 
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
 vae_kwargs:
-  enable_magvit: false
+  vae_type: "AutoencoderKL"
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/{easyanimate_video_slicevae_motion_module_v3.yaml → easyanimate_video_v2_magvit_motion_module.yaml} RENAMED
@@ -1,4 +1,5 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
   basic_block_type: "motionmodule"
@@ -15,13 +16,14 @@ transformer_additional_kwargs:
     temporal_attention_dim_div: 1
     block_size: 1
 
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
 vae_kwargs:
-  enable_magvit: true
-  slice_compression_vae: true
-  mini_batch_encoder: 8
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 9
+  mini_batch_decoder: 3
+  slice_mag_vae: true
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/{easyanimate_video_magvit_motion_module_v2.yaml → easyanimate_video_v3_slicevae_motion_module.yaml} RENAMED
@@ -1,26 +1,39 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
-  basic_block_type: "motionmodule"
+  basic_block_type: "global_motionmodule"
   time_position_encoding_before_transformer: false
   motion_module_type: "Vanilla"
   enable_uvit: true
 
-  motion_module_kwargs:
-    num_attention_heads: 8
+  motion_module_kwargs_even:
+    num_attention_heads: 16
     num_transformer_block: 1
     attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
     temporal_position_encoding: true
     temporal_position_encoding_max_len: 4096
     temporal_attention_dim_div: 1
     block_size: 1
-
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
+    remove_time_embedding_in_photo: false
+  motion_module_kwargs_odd:
+    num_attention_heads: 16
+    num_transformer_block: 1
+    attention_block_types: [ "Temporal_Self", "Global_Self" ]
+    temporal_position_encoding: true
+    temporal_position_encoding_max_len: 4096
+    temporal_attention_dim_div: 1
+    block_size: 1
+    remove_time_embedding_in_photo: false
 
 vae_kwargs:
-  enable_magvit: true
-  mini_batch_encoder: 9
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 8
+  mini_batch_decoder: 2
+  slice_mag_vae: false
+  slice_compression_vae: true
+  cache_compression_vae: false
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml ADDED
@@ -0,0 +1,20 @@
+transformer_additional_kwargs:
+  transformer_type: "HunyuanTransformer3DModel"
+  basic_block_type: "basic"
+  after_norm: false
+  time_position_encoding_type: "2d_rope"
+  time_position_encoding: true
+  resize_inpaint_mask_directly: false
+  enable_clip_in_inpaint: true
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 8
+  mini_batch_decoder: 2
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: true
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: true
config/easyanimate_video_v5_magvit_multi_text_encoder.yaml ADDED
@@ -0,0 +1,19 @@
+transformer_additional_kwargs:
+  transformer_type: "EasyAnimateTransformer3DModel"
+  after_norm: false
+  time_position_encoding_type: "3d_rope"
+  resize_inpaint_mask_directly: true
+  enable_text_attention_mask: false
+  enable_clip_in_inpaint: false
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 4
+  mini_batch_decoder: 1
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: true
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: true
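Editor's note: the reorganized configs drop the scheduler block and instead switch models by `transformer_type`, `vae_type`, and `enable_multi_text_encoder`. A minimal sketch of reading one of these files, assuming OmegaConf is used for YAML loading (the loading code itself is not part of this diff):

```python
# Minimal sketch, assuming OmegaConf reads these YAML configs.
from omegaconf import OmegaConf

config = OmegaConf.load("config/easyanimate_video_v5_magvit_multi_text_encoder.yaml")

transformer_kwargs = config["transformer_additional_kwargs"]
vae_kwargs = config["vae_kwargs"]

print(transformer_kwargs["transformer_type"])   # "EasyAnimateTransformer3DModel"
print(vae_kwargs["vae_type"], vae_kwargs["cache_mag_vae"])   # "AutoencoderKLMagvit" True
print(config["text_encoder_kwargs"]["enable_multi_text_encoder"])   # True
```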
config/zero_stage2_config.json ADDED
@@ -0,0 +1,16 @@
+{
+  "bf16": {
+    "enabled": true
+  },
+  "train_micro_batch_size_per_gpu": 1,
+  "train_batch_size": "auto",
+  "gradient_accumulation_steps": "auto",
+  "dump_state": true,
+  "zero_optimization": {
+    "stage": 2,
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": 5e8
+  }
+}
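Editor's note: the commit does not show how this ZeRO stage-2 file is consumed. A hedged sketch of pointing 🤗 Accelerate at it (DeepSpeed must be installed; the surrounding training script, model, optimizer, and dataloader are placeholders):

```python
# Hedged sketch: wiring config/zero_stage2_config.json into an Accelerate-driven
# training loop. The "auto" fields in the JSON are resolved by Accelerate at prepare() time.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(hf_ds_config="config/zero_stage2_config.json")
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

# model, optimizer, and dataloader come from the (not shown) training script:
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```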
easyanimate/api/api.py CHANGED
@@ -1,15 +1,17 @@
-import io
-import gc
 import base64
-import torch
-import gradio as gr
-import tempfile
+import gc
 import hashlib
+import io
+import os
+import tempfile
+from io import BytesIO
 
+import gradio as gr
+import torch
 from fastapi import FastAPI
-from io import BytesIO
 from PIL import Image
 
+
 # Function to encode a file to Base64
 def encode_file_to_base64(file_path):
     with open(file_path, "rb") as file:
@@ -53,6 +55,34 @@ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
 
     return {"message": comment}
 
+def save_base64_video(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.mp4"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def save_base64_image(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
 def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
     @app.post("/easyanimate/infer_forward")
     def _infer_forward_api(
@@ -63,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.')
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject.')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
         resize_method = datas.get('resize_method', "Generate by")
@@ -72,17 +102,20 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         base_resolution = datas.get('base_resolution', 512)
         is_image = datas.get('is_image', False)
         generation_method = datas.get('generation_method', False)
-        length_slider = datas.get('length_slider', 144)
+        length_slider = datas.get('length_slider', 49)
         overlap_video_length = datas.get('overlap_video_length', 4)
         partial_video_length = datas.get('partial_video_length', 72)
         cfg_scale_slider = datas.get('cfg_scale_slider', 6)
         start_image = datas.get('start_image', None)
         end_image = datas.get('end_image', None)
+        validation_video = datas.get('validation_video', None)
+        validation_video_mask = datas.get('validation_video_mask', None)
+        control_video = datas.get('control_video', None)
+        denoise_strength = datas.get('denoise_strength', 0.70)
         seed_textbox = datas.get("seed_textbox", 43)
 
         generation_method = "Image Generation" if is_image else generation_method
 
-        temp_directory = tempfile.gettempdir()
         if start_image is not None:
             start_image = base64.b64decode(start_image)
             start_image = [Image.open(BytesIO(start_image))]
@@ -91,6 +124,15 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
             end_image = base64.b64decode(end_image)
             end_image = [Image.open(BytesIO(end_image))]
 
+        if validation_video is not None:
+            validation_video = save_base64_video(validation_video)
+
+        if validation_video_mask is not None:
+            validation_video_mask = save_base64_image(validation_video_mask)
+
+        if control_video is not None:
+            control_video = save_base64_video(control_video)
+
         try:
             save_sample_path, comment = controller.generate(
                 "",
@@ -113,6 +155,10 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
                 cfg_scale_slider,
                 start_image,
                 end_image,
+                validation_video,
+                validation_video_mask,
+                control_video,
+                denoise_strength,
                 seed_textbox,
                 is_api = True,
             )
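Editor's note: the extended `/easyanimate/infer_forward` route now accepts base64-encoded `validation_video`, `validation_video_mask`, and `control_video` fields, which the server re-materializes with `save_base64_video` / `save_base64_image`. An illustrative client call follows; the host, file names, and prompt are examples, not repository code.

```python
# Illustrative client call against the extended /easyanimate/infer_forward route.
import base64
import requests

def encode_file_to_base64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt_textbox": "A dog runs on the grass.",
    "generation_method": "Video Generation",
    "length_slider": 49,
    "validation_video": encode_file_to_base64("input.mp4"),      # saved server-side via save_base64_video
    "validation_video_mask": encode_file_to_base64("mask.jpg"),  # saved server-side via save_base64_image
    "denoise_strength": 0.70,
    "seed_textbox": 43,
}
response = requests.post("http://127.0.0.1:7860/easyanimate/infer_forward", json=payload)
print(response.json())
```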
easyanimate/api/post_infer.py CHANGED
@@ -7,7 +7,6 @@ from io import BytesIO
 
 import cv2
 import requests
-import base64
 
 
 def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
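Editor's note: a small companion sketch for this client script, decoding the base64 video the server is expected to return. The `"base64_encoding"` field name follows the existing client conventions and is otherwise an assumption.

```python
# Hedged companion to post_infer.py: write the base64-encoded result video to disk.
import base64

def save_response_video(response_json, output_path="result.mp4"):
    # Assumes the server returns the generated video under "base64_encoding".
    video_bytes = base64.b64decode(response_json["base64_encoding"])
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return output_path
```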
easyanimate/data/dataset_image_video.py CHANGED
@@ -1,24 +1,23 @@
 import csv
+import gc
 import io
 import json
 import math
 import os
 import random
+from contextlib import contextmanager
 from threading import Thread
 
 import albumentations
 import cv2
-import gc
 import numpy as np
 import torch
 import torchvision.transforms as transforms
-
-from func_timeout import func_timeout, FunctionTimedOut
 from decord import VideoReader
+from func_timeout import FunctionTimedOut, func_timeout
 from PIL import Image
 from torch.utils.data import BatchSampler, Sampler
 from torch.utils.data.dataset import Dataset
-from contextlib import contextmanager
 
 VIDEO_READER_TIMEOUT = 20
 
@@ -26,9 +25,9 @@ def get_random_mask(shape):
     f, c, h, w = shape
 
     if f != 1:
-        mask_index = np.random.randint(1, 4)
+        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
     else:
-        mask_index = np.random.randint(1, 2)
+        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
     mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
 
     if mask_index == 0:
@@ -64,6 +63,40 @@ def get_random_mask(shape):
         mask_frame_before = np.random.randint(0, f // 2)
         mask_frame_after = np.random.randint(f // 2, f)
         mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 5:
+        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
+    elif mask_index == 6:
+        num_frames_to_mask = random.randint(1, max(f // 2, 1))
+        frames_to_mask = random.sample(range(f), num_frames_to_mask)
+
+        for i in frames_to_mask:
+            block_height = random.randint(1, h // 4)
+            block_width = random.randint(1, w // 4)
+            top_left_y = random.randint(0, h - block_height)
+            top_left_x = random.randint(0, w - block_width)
+            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
+    elif mask_index == 7:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()  # semi-major axis
+        b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()  # semi-minor axis
+
+        for i in range(h):
+            for j in range(w):
+                if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
+                    mask[:, :, i, j] = 1
+    elif mask_index == 8:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
+        for i in range(h):
+            for j in range(w):
+                if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
+                    mask[:, :, i, j] = 1
+    elif mask_index == 9:
+        for idx in range(f):
+            if np.random.rand() > 0.5:
+                mask[idx, :, :, :] = 1
     else:
         raise ValueError(f"The mask_index {mask_index} is not defined")
     return mask
@@ -128,19 +161,35 @@ def get_video_reader_batch(video_reader, batch_index):
     frames = video_reader.get_batch(batch_index).asnumpy()
     return frames
 
+def resize_frame(frame, target_short_side):
+    h, w, _ = frame.shape
+    if h < w:
+        if target_short_side > h:
+            return frame
+        new_h = target_short_side
+        new_w = int(target_short_side * w / h)
+    else:
+        if target_short_side > w:
+            return frame
+        new_w = target_short_side
+        new_h = int(target_short_side * h / w)
+
+    resized_frame = cv2.resize(frame, (new_w, new_h))
+    return resized_frame
+
 class ImageVideoDataset(Dataset):
     def __init__(
         self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=-1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
    ):
         # Loading annotations from files
         print(f"loading annotations from {ann_path} ...")
         if ann_path.endswith('.csv'):
@@ -176,11 +225,11 @@ class ImageVideoDataset(Dataset):
         # Video params
         self.video_sample_stride = video_sample_stride
         self.video_sample_n_frames = video_sample_n_frames
-        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
         self.video_transforms = transforms.Compose(
             [
-                transforms.Resize(video_sample_size[0]),
-                transforms.CenterCrop(video_sample_size),
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
@@ -193,7 +242,9 @@ class ImageVideoDataset(Dataset):
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
         ])
-
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
     def get_batch(self, idx):
         data_info = self.dataset[idx % len(self.dataset)]
 
@@ -208,7 +259,7 @@ class ImageVideoDataset(Dataset):
         with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
             min_sample_n_frames = min(
                 self.video_sample_n_frames,
-                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start))
+                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
             )
             if min_sample_n_frames == 0:
                 raise ValueError(f"No Frames in video.")
@@ -223,6 +274,12 @@ class ImageVideoDataset(Dataset):
                 pixel_values = func_timeout(
                     VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                 )
+                resized_frames = []
+                for i in range(len(pixel_values)):
+                    frame = pixel_values[i]
+                    resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                    resized_frames.append(resized_frame)
+                pixel_values = np.array(resized_frames)
             except FunctionTimedOut:
                 raise ValueError(f"Read {idx} timeout.")
             except Exception as e:
@@ -291,6 +348,238 @@ class ImageVideoDataset(Dataset):
             clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
             sample["clip_pixel_values"] = clip_pixel_values
 
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
+        return sample
+
+
+ class ImageVideoControlDataset(Dataset):
360
+ def __init__(
361
+ self,
362
+ ann_path, data_root=None,
363
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
364
+ image_sample_size=512,
365
+ video_repeat=0,
366
+ text_drop_ratio=-1,
367
+ enable_bucket=False,
368
+ video_length_drop_start=0.1,
369
+ video_length_drop_end=0.9,
370
+ enable_inpaint=False,
371
+ ):
372
+ # Loading annotations from files
373
+ print(f"loading annotations from {ann_path} ...")
374
+ if ann_path.endswith('.csv'):
375
+ with open(ann_path, 'r') as csvfile:
376
+ dataset = list(csv.DictReader(csvfile))
377
+ elif ann_path.endswith('.json'):
378
+ dataset = json.load(open(ann_path))
379
+
380
+ self.data_root = data_root
381
+
382
+ # It's used to balance num of images and videos.
383
+ self.dataset = []
384
+ for data in dataset:
385
+ if data.get('type', 'image') != 'video':
386
+ self.dataset.append(data)
387
+ if video_repeat > 0:
388
+ for _ in range(video_repeat):
389
+ for data in dataset:
390
+ if data.get('type', 'image') == 'video':
391
+ self.dataset.append(data)
392
+ del dataset
393
+
394
+ self.length = len(self.dataset)
395
+ print(f"data scale: {self.length}")
396
+ # TODO: enable bucket training
397
+ self.enable_bucket = enable_bucket
398
+ self.text_drop_ratio = text_drop_ratio
399
+ self.enable_inpaint = enable_inpaint
400
+
401
+ self.video_length_drop_start = video_length_drop_start
402
+ self.video_length_drop_end = video_length_drop_end
403
+
404
+ # Video params
405
+ self.video_sample_stride = video_sample_stride
406
+ self.video_sample_n_frames = video_sample_n_frames
407
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
408
+ self.video_transforms = transforms.Compose(
409
+ [
410
+ transforms.Resize(min(self.video_sample_size)),
411
+ transforms.CenterCrop(self.video_sample_size),
412
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
413
+ ]
414
+ )
415
+
416
+ # Image params
417
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
418
+ self.image_transforms = transforms.Compose([
419
+ transforms.Resize(min(self.image_sample_size)),
420
+ transforms.CenterCrop(self.image_sample_size),
421
+ transforms.ToTensor(),
422
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
423
+ ])
424
+
425
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
426
+
427
+ def get_batch(self, idx):
428
+ data_info = self.dataset[idx % len(self.dataset)]
429
+ video_id, text = data_info['file_path'], data_info['text']
430
+
431
+ if data_info.get('type', 'image')=='video':
432
+ if self.data_root is None:
433
+ video_dir = video_id
434
+ else:
435
+ video_dir = os.path.join(self.data_root, video_id)
436
+
437
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
438
+ min_sample_n_frames = min(
439
+ self.video_sample_n_frames,
440
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
441
+ )
442
+ if min_sample_n_frames == 0:
443
+ raise ValueError(f"No Frames in video.")
444
+
445
+ video_length = int(self.video_length_drop_end * len(video_reader))
446
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
447
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
448
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
449
+
450
+ try:
451
+ sample_args = (video_reader, batch_index)
452
+ pixel_values = func_timeout(
453
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
454
+ )
455
+ resized_frames = []
456
+ for i in range(len(pixel_values)):
457
+ frame = pixel_values[i]
458
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
459
+ resized_frames.append(resized_frame)
460
+ pixel_values = np.array(resized_frames)
461
+ except FunctionTimedOut:
462
+ raise ValueError(f"Read {idx} timeout.")
463
+ except Exception as e:
464
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
465
+
466
+ if not self.enable_bucket:
467
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
468
+ pixel_values = pixel_values / 255.
469
+ del video_reader
470
+ else:
471
+ pixel_values = pixel_values
472
+
473
+ if not self.enable_bucket:
474
+ pixel_values = self.video_transforms(pixel_values)
475
+
476
+ # Random use no text generation
477
+ if random.random() < self.text_drop_ratio:
478
+ text = ''
479
+
480
+ control_video_id = data_info['control_file_path']
481
+
482
+ if self.data_root is None:
483
+ control_video_id = control_video_id
484
+ else:
485
+ control_video_id = os.path.join(self.data_root, control_video_id)
486
+
487
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
488
+ try:
489
+ sample_args = (control_video_reader, batch_index)
490
+ control_pixel_values = func_timeout(
491
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
492
+ )
493
+ resized_frames = []
494
+ for i in range(len(control_pixel_values)):
495
+ frame = control_pixel_values[i]
496
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
497
+ resized_frames.append(resized_frame)
498
+ control_pixel_values = np.array(resized_frames)
499
+ except FunctionTimedOut:
500
+ raise ValueError(f"Read {idx} timeout.")
501
+ except Exception as e:
502
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
503
+
504
+ if not self.enable_bucket:
505
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
506
+ control_pixel_values = control_pixel_values / 255.
507
+ del control_video_reader
508
+ else:
509
+ control_pixel_values = control_pixel_values
510
+
511
+ if not self.enable_bucket:
512
+ control_pixel_values = self.video_transforms(control_pixel_values)
513
+ return pixel_values, control_pixel_values, text, "video"
514
+ else:
515
+ image_path, text = data_info['file_path'], data_info['text']
516
+ if self.data_root is not None:
517
+ image_path = os.path.join(self.data_root, image_path)
518
+ image = Image.open(image_path).convert('RGB')
519
+ if not self.enable_bucket:
520
+ image = self.image_transforms(image).unsqueeze(0)
521
+ else:
522
+ image = np.expand_dims(np.array(image), 0)
523
+
524
+ if random.random() < self.text_drop_ratio:
525
+ text = ''
526
+
527
+ control_image_id = data_info['control_file_path']
528
+
529
+ if self.data_root is None:
530
+ control_image_id = control_image_id
531
+ else:
532
+ control_image_id = os.path.join(self.data_root, control_image_id)
533
+
534
+ control_image = Image.open(control_image_id).convert('RGB')
535
+ if not self.enable_bucket:
536
+ control_image = self.image_transforms(control_image).unsqueeze(0)
537
+ else:
538
+ control_image = np.expand_dims(np.array(control_image), 0)
539
+ return image, control_image, text, 'image'
540
+
541
+ def __len__(self):
542
+ return self.length
543
+
544
+ def __getitem__(self, idx):
545
+ data_info = self.dataset[idx % len(self.dataset)]
546
+ data_type = data_info.get('type', 'image')
547
+ while True:
548
+ sample = {}
549
+ try:
550
+ data_info_local = self.dataset[idx % len(self.dataset)]
551
+ data_type_local = data_info_local.get('type', 'image')
552
+ if data_type_local != data_type:
553
+ raise ValueError("data_type_local != data_type")
554
+
555
+ pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
556
+ sample["pixel_values"] = pixel_values
557
+ sample["control_pixel_values"] = control_pixel_values
558
+ sample["text"] = name
559
+ sample["data_type"] = data_type
560
+ sample["idx"] = idx
561
+
562
+ if len(sample) > 0:
563
+ break
564
+ except Exception as e:
565
+ print(e, self.dataset[idx % len(self.dataset)])
566
+ idx = random.randint(0, self.length-1)
567
+
568
+ if self.enable_inpaint and not self.enable_bucket:
569
+ mask = get_random_mask(pixel_values.size())
570
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
571
+ sample["mask_pixel_values"] = mask_pixel_values
572
+ sample["mask"] = mask
573
+
574
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
575
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
576
+ sample["clip_pixel_values"] = clip_pixel_values
577
+
578
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
579
+ if (mask == 1).all():
580
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
581
+ sample["ref_pixel_values"] = ref_pixel_values
582
+
583
  return sample
584
 
585
  if __name__ == "__main__":
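Editor's note: a minimal usage sketch for the new `ImageVideoControlDataset` added above. The annotation file, its `control_file_path` entries, and the paths are hypothetical; only the constructor arguments and returned keys come from the diff.

```python
# Minimal sketch: iterating the new control dataset with a PyTorch DataLoader.
import torch
from easyanimate.data.dataset_image_video import ImageVideoControlDataset

dataset = ImageVideoControlDataset(
    ann_path="datasets/metadata_control.json",  # each entry needs file_path, control_file_path, text, type
    data_root="datasets",
    video_sample_size=512,
    video_sample_n_frames=49,
    enable_inpaint=False,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)

batch = next(iter(loader))
# Keys produced by __getitem__: pixel_values, control_pixel_values, text, data_type, idx
print(batch["pixel_values"].shape, batch["control_pixel_values"].shape, batch["text"])
```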
easyanimate/models/__init__.py ADDED
@@ -0,0 +1,16 @@
+from .autoencoder_magvit import (AutoencoderKLCogVideoX, AutoencoderKLMagvit, AutoencoderKL)
+from .transformer3d import (EasyAnimateTransformer3DModel,
+                            HunyuanTransformer3DModel,
+                            Transformer3DModel)
+
+
+name_to_transformer3d = {
+    "Transformer3DModel": Transformer3DModel,
+    "HunyuanTransformer3DModel": HunyuanTransformer3DModel,
+    "EasyAnimateTransformer3DModel": EasyAnimateTransformer3DModel,
+}
+name_to_autoencoder_magvit = {
+    "AutoencoderKL": AutoencoderKL,
+    "AutoencoderKLMagvit": AutoencoderKLMagvit,
+    "AutoencoderKLCogVideoX": AutoencoderKLCogVideoX,
+}
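Editor's note: these registries pair with the `transformer_type` / `vae_type` keys introduced in the config files. A hedged sketch of the dispatch, assuming OmegaConf loading; the actual model-instantiation calls are not shown in this diff.

```python
# Hedged sketch: resolve model classes from the new registries using config keys.
from omegaconf import OmegaConf

from easyanimate.models import name_to_autoencoder_magvit, name_to_transformer3d

config = OmegaConf.load("config/easyanimate_video_v5_magvit_multi_text_encoder.yaml")

Transformer3DCls = name_to_transformer3d[
    config["transformer_additional_kwargs"].get("transformer_type", "Transformer3DModel")
]
AutoencoderCls = name_to_autoencoder_magvit[
    config["vae_kwargs"].get("vae_type", "AutoencoderKL")
]
print(Transformer3DCls.__name__, AutoencoderCls.__name__)
```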
easyanimate/models/attention.py CHANGED
@@ -11,34 +11,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
 
 import diffusers
 import pkg_resources
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
-
-installed_version = diffusers.__version__
-
-if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
-    from diffusers.models.attention_processor import (Attention,
-                                                       AttnProcessor2_0,
-                                                       HunyuanAttnProcessor2_0)
-else:
-    from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-
-from diffusers.models.attention import AdaLayerNorm, FeedForward
-from diffusers.models.embeddings import SinusoidalPositionalEmbedding
-from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
-from diffusers.utils import USE_PEFT_BACKEND
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange, repeat
 from torch import nn
 
 from .motion_module import PositionalEncoding, get_motion_module
-from .norm import FP32LayerNorm
 
 if is_xformers_available():
     import xformers
@@ -53,7 +57,6 @@ def zero_module(module):
         p.detach().zero_()
     return module
 
-
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
     r"""
@@ -95,267 +98,33 @@ class GatedSelfAttentionDense(nn.Module):
95
 
96
  return x
97
 
98
-
99
- class KVCompressionCrossAttention(nn.Module):
100
- r"""
101
- A cross attention layer.
102
-
103
- Parameters:
104
- query_dim (`int`): The number of channels in the query.
105
- cross_attention_dim (`int`, *optional*):
106
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
107
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
108
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
109
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
110
- bias (`bool`, *optional*, defaults to False):
111
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
112
- """
113
-
114
  def __init__(
115
- self,
116
- query_dim: int,
117
- cross_attention_dim: Optional[int] = None,
118
- heads: int = 8,
119
- dim_head: int = 64,
120
- dropout: float = 0.0,
121
- bias=False,
122
- upcast_attention: bool = False,
123
- upcast_softmax: bool = False,
124
- added_kv_proj_dim: Optional[int] = None,
125
- norm_num_groups: Optional[int] = None,
126
- ):
127
- super().__init__()
128
- inner_dim = dim_head * heads
129
- cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
130
- self.upcast_attention = upcast_attention
131
- self.upcast_softmax = upcast_softmax
132
-
133
- self.scale = dim_head**-0.5
134
-
135
- self.heads = heads
136
- # for slice_size > 0 the attention score computation
137
- # is split across the batch axis to save memory
138
- # You can set slice_size with `set_attention_slice`
139
- self.sliceable_head_dim = heads
140
- self._slice_size = None
141
- self._use_memory_efficient_attention_xformers = True
142
- self.added_kv_proj_dim = added_kv_proj_dim
143
-
144
- if norm_num_groups is not None:
145
- self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
146
- else:
147
- self.group_norm = None
148
-
149
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
150
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
151
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
152
-
153
- if self.added_kv_proj_dim is not None:
154
- self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
155
- self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
156
-
157
- self.kv_compression = nn.Conv2d(
158
- query_dim,
159
- query_dim,
160
- groups=query_dim,
161
- kernel_size=2,
162
- stride=2,
163
  bias=True
164
  )
165
- self.kv_compression_norm = FP32LayerNorm(query_dim)
166
- init.constant_(self.kv_compression.weight, 1 / 4)
167
- if self.kv_compression.bias is not None:
168
- init.constant_(self.kv_compression.bias, 0)
169
-
170
- self.to_out = nn.ModuleList([])
171
- self.to_out.append(nn.Linear(inner_dim, query_dim))
172
- self.to_out.append(nn.Dropout(dropout))
173
-
174
- def reshape_heads_to_batch_dim(self, tensor):
175
- batch_size, seq_len, dim = tensor.shape
176
- head_size = self.heads
177
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
178
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
179
- return tensor
180
-
181
- def reshape_batch_dim_to_heads(self, tensor):
182
- batch_size, seq_len, dim = tensor.shape
183
- head_size = self.heads
184
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
185
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
186
- return tensor
187
-
188
- def set_attention_slice(self, slice_size):
189
- if slice_size is not None and slice_size > self.sliceable_head_dim:
190
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
191
-
192
- self._slice_size = slice_size
193
-
194
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
195
- batch_size, sequence_length, _ = hidden_states.shape
196
-
197
- encoder_hidden_states = encoder_hidden_states
198
-
199
- if self.group_norm is not None:
200
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
201
-
202
- query = self.to_q(hidden_states)
203
- dim = query.shape[-1]
204
- query = self.reshape_heads_to_batch_dim(query)
205
-
206
- if self.added_kv_proj_dim is not None:
207
- key = self.to_k(hidden_states)
208
- value = self.to_v(hidden_states)
209
-
210
- encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
211
- encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
212
-
213
- key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
214
- key = self.kv_compression(key)
215
- key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
216
- key = self.kv_compression_norm(key)
217
- key = key.to(query.dtype)
218
-
219
- value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
220
- value = self.kv_compression(value)
221
- value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
222
- value = self.kv_compression_norm(value)
223
- value = value.to(query.dtype)
224
-
225
- key = self.reshape_heads_to_batch_dim(key)
226
- value = self.reshape_heads_to_batch_dim(value)
227
- encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
228
- encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
229
-
230
- key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
231
- value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
232
- else:
233
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
234
- key = self.to_k(encoder_hidden_states)
235
- value = self.to_v(encoder_hidden_states)
236
-
237
- key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
238
- key = self.kv_compression(key)
239
- key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
240
- key = self.kv_compression_norm(key)
241
- key = key.to(query.dtype)
242
-
243
- value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
244
- value = self.kv_compression(value)
245
- value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
246
- value = self.kv_compression_norm(value)
247
- value = value.to(query.dtype)
248
-
249
- key = self.reshape_heads_to_batch_dim(key)
250
- value = self.reshape_heads_to_batch_dim(value)
251
-
252
- if attention_mask is not None:
253
- if attention_mask.shape[-1] != query.shape[1]:
254
- target_length = query.shape[1]
255
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
256
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
257
-
258
- # attention, what we cannot get enough of
259
- if self._use_memory_efficient_attention_xformers:
260
- hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
261
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
262
- hidden_states = hidden_states.to(query.dtype)
263
- else:
264
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
265
- hidden_states = self._attention(query, key, value, attention_mask)
266
- else:
267
- hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
268
-
269
- # linear proj
270
- hidden_states = self.to_out[0](hidden_states)
271
-
272
- # dropout
273
- hidden_states = self.to_out[1](hidden_states)
274
- return hidden_states
275
-
276
- def _attention(self, query, key, value, attention_mask=None):
277
- if self.upcast_attention:
278
- query = query.float()
279
- key = key.float()
280
-
281
- attention_scores = torch.baddbmm(
282
- torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
283
- query,
284
- key.transpose(-1, -2),
285
- beta=0,
286
- alpha=self.scale,
287
- )
288
-
289
- if attention_mask is not None:
290
- attention_scores = attention_scores + attention_mask
291
-
292
- if self.upcast_softmax:
293
- attention_scores = attention_scores.float()
294
-
295
- attention_probs = attention_scores.softmax(dim=-1)
296
-
297
- # cast back to the original dtype
298
- attention_probs = attention_probs.to(value.dtype)
299
-
300
- # compute attention output
301
- hidden_states = torch.bmm(attention_probs, value)
302
-
303
- # reshape hidden_states
304
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
305
- return hidden_states
306
-
307
- def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
308
- batch_size_attention = query.shape[0]
309
- hidden_states = torch.zeros(
310
- (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
311
  )
312
- slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
313
- for i in range(hidden_states.shape[0] // slice_size):
314
- start_idx = i * slice_size
315
- end_idx = (i + 1) * slice_size
316
-
317
- query_slice = query[start_idx:end_idx]
318
- key_slice = key[start_idx:end_idx]
319
-
320
- if self.upcast_attention:
321
- query_slice = query_slice.float()
322
- key_slice = key_slice.float()
323
-
324
- attn_slice = torch.baddbmm(
325
- torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
326
- query_slice,
327
- key_slice.transpose(-1, -2),
328
- beta=0,
329
- alpha=self.scale,
330
- )
331
-
332
- if attention_mask is not None:
333
- attn_slice = attn_slice + attention_mask[start_idx:end_idx]
334
-
335
- if self.upcast_softmax:
336
- attn_slice = attn_slice.float()
337
-
338
- attn_slice = attn_slice.softmax(dim=-1)
339
-
340
- # cast back to the original dtype
341
- attn_slice = attn_slice.to(value.dtype)
342
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
343
-
344
- hidden_states[start_idx:end_idx] = attn_slice
345
-
346
- # reshape hidden_states
347
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
348
- return hidden_states
349
-
350
- def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
351
- # TODO attention_mask
352
- query = query.contiguous()
353
- key = key.contiguous()
354
- value = value.contiguous()
355
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
356
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
357
- return hidden_states
358
-
359
 
360
  @maybe_allow_in_graph
361
  class TemporalTransformerBlock(nn.Module):
@@ -413,8 +182,6 @@ class TemporalTransformerBlock(nn.Module):
413
  attention_type: str = "default",
414
  positional_embeddings: Optional[str] = None,
415
  num_positional_embeddings: Optional[int] = None,
416
- # kv compression
417
- kvcompression: Optional[bool] = False,
418
  # motion module kwargs
419
  motion_module_type = "VanillaGrid",
420
  motion_module_kwargs = None,
@@ -454,40 +221,17 @@ class TemporalTransformerBlock(nn.Module):
454
  else:
455
  self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
456
 
457
- self.kvcompression = kvcompression
458
- if kvcompression:
459
- self.attn1 = KVCompressionCrossAttention(
460
- query_dim=dim,
461
- heads=num_attention_heads,
462
- dim_head=attention_head_dim,
463
- dropout=dropout,
464
- bias=attention_bias,
465
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
466
- upcast_attention=upcast_attention,
467
- )
468
- else:
469
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
470
- self.attn1 = Attention(
471
- query_dim=dim,
472
- heads=num_attention_heads,
473
- dim_head=attention_head_dim,
474
- dropout=dropout,
475
- bias=attention_bias,
476
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
477
- upcast_attention=upcast_attention,
478
- qk_norm="layer_norm" if qk_norm else None,
479
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
480
- )
481
- else:
482
- self.attn1 = Attention(
483
- query_dim=dim,
484
- heads=num_attention_heads,
485
- dim_head=attention_head_dim,
486
- dropout=dropout,
487
- bias=attention_bias,
488
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
489
- upcast_attention=upcast_attention,
490
- )
491
 
492
  self.attn_temporal = get_motion_module(
493
  in_channels = dim,
@@ -505,28 +249,17 @@ class TemporalTransformerBlock(nn.Module):
505
  if self.use_ada_layer_norm
506
  else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
507
  )
508
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
509
- self.attn2 = Attention(
510
- query_dim=dim,
511
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
512
- heads=num_attention_heads,
513
- dim_head=attention_head_dim,
514
- dropout=dropout,
515
- bias=attention_bias,
516
- upcast_attention=upcast_attention,
517
- qk_norm="layer_norm" if qk_norm else None,
518
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
519
- ) # is self-attn if encoder_hidden_states is none
520
- else:
521
- self.attn2 = Attention(
522
- query_dim=dim,
523
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
524
- heads=num_attention_heads,
525
- dim_head=attention_head_dim,
526
- dropout=dropout,
527
- bias=attention_bias,
528
- upcast_attention=upcast_attention,
529
- ) # is self-attn if encoder_hidden_states is none
530
  else:
531
  self.norm2 = None
532
  self.attn2 = None
@@ -605,23 +338,12 @@ class TemporalTransformerBlock(nn.Module):
605
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
606
 
607
  norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
608
- if self.kvcompression:
609
- attn_output = self.attn1(
610
- norm_hidden_states,
611
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
612
- attention_mask=attention_mask,
613
- num_frames=1,
614
- height=height,
615
- width=width,
616
- **cross_attention_kwargs,
617
- )
618
- else:
619
- attn_output = self.attn1(
620
- norm_hidden_states,
621
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
622
- attention_mask=attention_mask,
623
- **cross_attention_kwargs,
624
- )
625
  attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
626
  if self.use_ada_layer_norm_zero:
627
  attn_output = gate_msa.unsqueeze(1) * attn_output
@@ -658,6 +380,9 @@ class TemporalTransformerBlock(nn.Module):
658
  if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
659
  norm_hidden_states = self.pos_embed(norm_hidden_states)
660
 
 
 
 
661
  attn_output = self.attn2(
662
  norm_hidden_states,
663
  encoder_hidden_states=encoder_hidden_states,
@@ -760,7 +485,7 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
760
  double_self_attention: bool = False,
761
  upcast_attention: bool = False,
762
  norm_elementwise_affine: bool = True,
763
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
764
  norm_eps: float = 1e-5,
765
  final_dropout: bool = False,
766
  attention_type: str = "default",
@@ -802,28 +527,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
802
  else:
803
  self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
804
 
805
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
806
- self.attn1 = Attention(
807
- query_dim=dim,
808
- heads=num_attention_heads,
809
- dim_head=attention_head_dim,
810
- dropout=dropout,
811
- bias=attention_bias,
812
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
813
- upcast_attention=upcast_attention,
814
- qk_norm="layer_norm" if qk_norm else None,
815
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
816
- )
817
- else:
818
- self.attn1 = Attention(
819
- query_dim=dim,
820
- heads=num_attention_heads,
821
- dim_head=attention_head_dim,
822
- dropout=dropout,
823
- bias=attention_bias,
824
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
825
- upcast_attention=upcast_attention,
826
- )
827
 
828
  # 2. Cross-Attn
829
  if cross_attention_dim is not None or double_self_attention:
@@ -835,28 +549,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
835
  if self.use_ada_layer_norm
836
  else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
837
  )
838
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
839
- self.attn2 = Attention(
840
- query_dim=dim,
841
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
842
- heads=num_attention_heads,
843
- dim_head=attention_head_dim,
844
- dropout=dropout,
845
- bias=attention_bias,
846
- upcast_attention=upcast_attention,
847
- qk_norm="layer_norm" if qk_norm else None,
848
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
849
- ) # is self-attn if encoder_hidden_states is none
850
- else:
851
- self.attn2 = Attention(
852
- query_dim=dim,
853
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
854
- heads=num_attention_heads,
855
- dim_head=attention_head_dim,
856
- dropout=dropout,
857
- bias=attention_bias,
858
- upcast_attention=upcast_attention,
859
- ) # is self-attn if encoder_hidden_states is none
860
  else:
861
  self.norm2 = None
862
  self.attn2 = None
@@ -1017,340 +720,415 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
1017
  hidden_states = hidden_states.squeeze(1)
1018
 
1019
  return hidden_states
1020
-
 
 
 
 
 
 
 
 
 
1021
 
1022
  @maybe_allow_in_graph
1023
- class KVCompressionTransformerBlock(nn.Module):
1024
  r"""
1025
- A Temporal Transformer block.
 
1026
 
1027
  Parameters:
1028
- dim (`int`): The number of channels in the input and output.
1029
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
1030
- attention_head_dim (`int`): The number of channels in each head.
1031
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1032
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
1033
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1034
- num_embeds_ada_norm (:
1035
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
1036
- attention_bias (:
1037
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
1038
- only_cross_attention (`bool`, *optional*):
1039
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
1040
- double_self_attention (`bool`, *optional*):
1041
- Whether to use two self-attention layers. In this case no cross attention layers are used.
1042
- upcast_attention (`bool`, *optional*):
1043
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
1044
  norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
1045
  Whether to use learnable elementwise affine parameters for normalization.
1046
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
1047
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
1048
  final_dropout (`bool` *optional*, defaults to False):
1049
  Whether to apply a final dropout after the last feed-forward layer.
1050
- attention_type (`str`, *optional*, defaults to `"default"`):
1051
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
1052
- positional_embeddings (`str`, *optional*, defaults to `None`):
1053
- The type of positional embeddings to apply to.
1054
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
1055
- The maximum number of positional embeddings to apply.
 
 
1056
  """
1057
 
1058
  def __init__(
1059
  self,
1060
  dim: int,
1061
  num_attention_heads: int,
1062
- attention_head_dim: int,
1063
  dropout=0.0,
1064
- cross_attention_dim: Optional[int] = None,
1065
  activation_fn: str = "geglu",
1066
- num_embeds_ada_norm: Optional[int] = None,
1067
- attention_bias: bool = False,
1068
- only_cross_attention: bool = False,
1069
- double_self_attention: bool = False,
1070
- upcast_attention: bool = False,
1071
  norm_elementwise_affine: bool = True,
1072
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
1073
- norm_eps: float = 1e-5,
1074
  final_dropout: bool = False,
1075
- attention_type: str = "default",
1076
- positional_embeddings: Optional[str] = None,
1077
- num_positional_embeddings: Optional[int] = None,
1078
- kvcompression: Optional[bool] = False,
1079
- qk_norm = False,
1080
- after_norm = False,
 
 
 
 
1081
  ):
1082
  super().__init__()
1083
- self.only_cross_attention = only_cross_attention
1084
-
1085
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
1086
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
1087
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
1088
- self.use_layer_norm = norm_type == "layer_norm"
1089
-
1090
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
1091
- raise ValueError(
1092
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
1093
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
1094
- )
1095
-
1096
- if positional_embeddings and (num_positional_embeddings is None):
1097
- raise ValueError(
1098
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
1099
- )
1100
-
1101
- if positional_embeddings == "sinusoidal":
1102
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
1103
- else:
1104
- self.pos_embed = None
1105
 
1106
  # Define 3 blocks. Each block has its own normalization layer.
 
1107
  # 1. Self-Attn
1108
- if self.use_ada_layer_norm:
1109
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
1110
- elif self.use_ada_layer_norm_zero:
1111
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
1112
- else:
1113
- self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1114
 
 
 
1115
  self.kvcompression = kvcompression
1116
  if kvcompression:
1117
- self.attn1 = KVCompressionCrossAttention(
1118
  query_dim=dim,
 
 
1119
  heads=num_attention_heads,
1120
- dim_head=attention_head_dim,
1121
- dropout=dropout,
1122
- bias=attention_bias,
1123
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1124
- upcast_attention=upcast_attention,
1125
  )
1126
  else:
1127
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
1128
- self.attn1 = Attention(
1129
- query_dim=dim,
1130
- heads=num_attention_heads,
1131
- dim_head=attention_head_dim,
1132
- dropout=dropout,
1133
- bias=attention_bias,
1134
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1135
- upcast_attention=upcast_attention,
1136
- qk_norm="layer_norm" if qk_norm else None,
1137
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
1138
- )
1139
- else:
1140
- self.attn1 = Attention(
1141
- query_dim=dim,
1142
- heads=num_attention_heads,
1143
- dim_head=attention_head_dim,
1144
- dropout=dropout,
1145
- bias=attention_bias,
1146
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1147
- upcast_attention=upcast_attention,
1148
- )
1149
 
1150
  # 2. Cross-Attn
1151
- if cross_attention_dim is not None or double_self_attention:
1152
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
1153
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
1154
- # the second cross attention block.
1155
- self.norm2 = (
1156
- AdaLayerNorm(dim, num_embeds_ada_norm)
1157
- if self.use_ada_layer_norm
1158
- else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
 
 
 
 
 
1159
  )
1160
- if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
1161
- self.attn2 = Attention(
1162
- query_dim=dim,
1163
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1164
- heads=num_attention_heads,
1165
- dim_head=attention_head_dim,
1166
- dropout=dropout,
1167
- bias=attention_bias,
1168
- upcast_attention=upcast_attention,
1169
- qk_norm="layer_norm" if qk_norm else None,
1170
- processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
1171
- ) # is self-attn if encoder_hidden_states is none
1172
- else:
1173
- self.attn2 = Attention(
1174
- query_dim=dim,
1175
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1176
- heads=num_attention_heads,
1177
- dim_head=attention_head_dim,
1178
- dropout=dropout,
1179
- bias=attention_bias,
1180
- upcast_attention=upcast_attention,
1181
- ) # is self-attn if encoder_hidden_states is none
1182
- else:
1183
- self.norm2 = None
1184
- self.attn2 = None
1185
 
 
 
 
 
 
 
1186
  # 3. Feed-forward
1187
- if not self.use_ada_layer_norm_single:
1188
- self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
 
 
 
 
 
 
 
1189
 
1190
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
 
 
 
 
1191
 
1192
  if after_norm:
1193
  self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1194
  else:
1195
  self.norm4 = None
1196
 
1197
- # 4. Fuser
1198
- if attention_type == "gated" or attention_type == "gated-text-image":
1199
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
1200
-
1201
- # 5. Scale-shift for PixArt-Alpha.
1202
- if self.use_ada_layer_norm_single:
1203
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
1204
-
1205
  # let chunk size default to None
1206
  self._chunk_size = None
1207
  self._chunk_dim = 0
1208
 
1209
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
1210
  # Sets chunk feed-forward
1211
  self._chunk_size = chunk_size
1212
  self._chunk_dim = dim
1213
 
1214
  def forward(
1215
  self,
1216
- hidden_states: torch.FloatTensor,
1217
- attention_mask: Optional[torch.FloatTensor] = None,
1218
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1219
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1220
- timestep: Optional[torch.LongTensor] = None,
1221
- cross_attention_kwargs: Dict[str, Any] = None,
1222
- class_labels: Optional[torch.LongTensor] = None,
1223
- num_frames: int = 16,
1224
  height: int = 32,
1225
  width: int = 32,
1226
- use_reentrant: bool = False,
1227
- ) -> torch.FloatTensor:
 
1228
  # Notice that normalization is always applied before the real computation in the following blocks.
1229
- # 0. Self-Attention
1230
- batch_size = hidden_states.shape[0]
1231
-
1232
- if self.use_ada_layer_norm:
1233
- norm_hidden_states = self.norm1(hidden_states, timestep)
1234
- elif self.use_ada_layer_norm_zero:
1235
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1236
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1237
- )
1238
- elif self.use_layer_norm:
1239
- norm_hidden_states = self.norm1(hidden_states)
1240
- elif self.use_ada_layer_norm_single:
1241
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1242
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1243
- ).chunk(6, dim=1)
1244
- norm_hidden_states = self.norm1(hidden_states)
1245
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1246
- norm_hidden_states = norm_hidden_states.squeeze(1)
1247
- else:
1248
- raise ValueError("Incorrect norm used")
1249
-
1250
- if self.pos_embed is not None:
1251
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1252
-
1253
- # 1. Retrieve lora scale.
1254
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1255
-
1256
- # 2. Prepare GLIGEN inputs
1257
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1258
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
1259
-
1260
- if self.kvcompression:
1261
  attn_output = self.attn1(
1262
- norm_hidden_states,
1263
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1264
- attention_mask=attention_mask,
1265
- num_frames=num_frames,
1266
- height=height,
1267
- width=width,
1268
- **cross_attention_kwargs,
1269
  )
1270
- else:
1271
- attn_output = self.attn1(
1272
- norm_hidden_states,
1273
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1274
- attention_mask=attention_mask,
1275
- **cross_attention_kwargs,
 
 
1276
  )
 
 
1277
 
1278
- if self.use_ada_layer_norm_zero:
1279
- attn_output = gate_msa.unsqueeze(1) * attn_output
1280
- elif self.use_ada_layer_norm_single:
1281
- attn_output = gate_msa * attn_output
1282
-
1283
- hidden_states = attn_output + hidden_states
1284
- if hidden_states.ndim == 4:
1285
- hidden_states = hidden_states.squeeze(1)
1286
-
1287
- # 2.5 GLIGEN Control
1288
- if gligen_kwargs is not None:
1289
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
1290
-
1291
- # 3. Cross-Attention
1292
- if self.attn2 is not None:
1293
- if self.use_ada_layer_norm:
1294
- norm_hidden_states = self.norm2(hidden_states, timestep)
1295
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
1296
- norm_hidden_states = self.norm2(hidden_states)
1297
- elif self.use_ada_layer_norm_single:
1298
- # For PixArt norm2 isn't applied here:
1299
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1300
- norm_hidden_states = hidden_states
1301
  else:
1302
- raise ValueError("Incorrect norm")
1303
-
1304
- if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
1305
- norm_hidden_states = self.pos_embed(norm_hidden_states)
 
 
 
 
 
 
1306
 
1307
- attn_output = self.attn2(
1308
- norm_hidden_states,
1309
- encoder_hidden_states=encoder_hidden_states,
1310
- attention_mask=encoder_attention_mask,
1311
- **cross_attention_kwargs,
 
 
 
 
1312
  )
1313
- hidden_states = attn_output + hidden_states
1314
 
1315
- # 4. Feed-forward
1316
- if not self.use_ada_layer_norm_single:
1317
- norm_hidden_states = self.norm3(hidden_states)
 
 
 
1318
 
1319
- if self.use_ada_layer_norm_zero:
1320
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1321
 
1322
- if self.use_ada_layer_norm_single:
1323
- norm_hidden_states = self.norm2(hidden_states)
1324
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
 
 
 
 
 
1325
 
1326
- if self._chunk_size is not None:
1327
- # "feed_forward_chunk_size" can be used to save memory
1328
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
1329
- raise ValueError(
1330
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
1331
- )
1332
 
1333
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
1334
- ff_output = torch.cat(
1335
- [
1336
- self.ff(hid_slice, scale=lora_scale)
1337
- for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
1338
- ],
1339
- dim=self._chunk_dim,
1340
- )
1341
- else:
1342
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
 
 
 
 
 
 
 
1343
 
1344
- if self.norm4 is not None:
1345
- ff_output = self.norm4(ff_output)
 
 
 
 
 
 
1346
 
1347
- if self.use_ada_layer_norm_zero:
1348
- ff_output = gate_mlp.unsqueeze(1) * ff_output
1349
- elif self.use_ada_layer_norm_single:
1350
- ff_output = gate_mlp * ff_output
 
 
 
 
 
 
 
1351
 
1352
- hidden_states = ff_output + hidden_states
1353
- if hidden_states.ndim == 4:
1354
- hidden_states = hidden_states.squeeze(1)
 
 
 
 
 
 
1355
 
1356
- return hidden_states
 
 
 
 
 
 
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union

 import diffusers
 import pkg_resources
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import (Attention,
+                                                   AttentionProcessor,
+                                                   AttnProcessor2_0,
+                                                   HunyuanAttnProcessor2_0)
+from diffusers.models.embeddings import (SinusoidalPositionalEmbedding,
+                                          TimestepEmbedding, Timesteps,
+                                          get_3d_sincos_pos_embed)
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
+                                             CogVideoXLayerNormZero)
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange, repeat
 from torch import nn

 from .motion_module import PositionalEncoding, get_motion_module
+from .norm import AdaLayerNormShift, FP32LayerNorm, EasyAnimateLayerNormZero
+from .processor import (EasyAnimateAttnProcessor2_0,
+                        LazyKVCompressionProcessor2_0)
+
+

 if is_xformers_available():
     import xformers
 
57
  p.detach().zero_()
58
  return module
59
 
 
60
  @maybe_allow_in_graph
61
  class GatedSelfAttentionDense(nn.Module):
62
  r"""
 
98
 
99
  return x
100
 
+class LazyKVCompressionAttention(Attention):
     def __init__(
+        self,
+        sr_ratio=2, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.sr_ratio = sr_ratio
+        self.k_compression = nn.Conv2d(
+            kwargs["query_dim"],
+            kwargs["query_dim"],
+            groups=kwargs["query_dim"],
+            kernel_size=sr_ratio,
+            stride=sr_ratio,
             bias=True
         )
+        self.v_compression = nn.Conv2d(
+            kwargs["query_dim"],
+            kwargs["query_dim"],
+            groups=kwargs["query_dim"],
+            kernel_size=sr_ratio,
+            stride=sr_ratio,
+            bias=True
         )
+        init.constant_(self.k_compression.weight, 1 / (sr_ratio * sr_ratio))
+        init.constant_(self.v_compression.weight, 1 / (sr_ratio * sr_ratio))
+        init.constant_(self.k_compression.bias, 0)
+        init.constant_(self.v_compression.bias, 0)
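To see what the two compression convolutions above do, here is a standalone sketch (hypothetical sizes, not the repo's attention processor): a depthwise `Conv2d` with `kernel_size == stride == sr_ratio` and averaging weights pools each `sr_ratio x sr_ratio` patch of the key/value feature map, so the number of KV tokens entering attention drops by `sr_ratio**2` while the queries stay at full resolution.

```python
import torch
import torch.nn as nn
import torch.nn.init as init

dim, sr_ratio, H, W = 64, 2, 32, 32

# Depthwise conv initialised as average pooling, mirroring k_compression / v_compression above.
compress = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, bias=True)
init.constant_(compress.weight, 1 / (sr_ratio * sr_ratio))
init.constant_(compress.bias, 0)

kv = torch.randn(1, dim, H, W)      # spatial key/value features
kv_small = compress(kv)             # (1, dim, H // sr_ratio, W // sr_ratio)
print(H * W, "->", kv_small.shape[-2] * kv_small.shape[-1])   # 1024 -> 256 KV tokens
```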
 
129
  @maybe_allow_in_graph
130
  class TemporalTransformerBlock(nn.Module):
 
182
  attention_type: str = "default",
183
  positional_embeddings: Optional[str] = None,
184
  num_positional_embeddings: Optional[int] = None,
 
 
185
  # motion module kwargs
186
  motion_module_type = "VanillaGrid",
187
  motion_module_kwargs = None,
 
221
  else:
222
  self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
223
 
224
+ self.attn1 = Attention(
225
+ query_dim=dim,
226
+ heads=num_attention_heads,
227
+ dim_head=attention_head_dim,
228
+ dropout=dropout,
229
+ bias=attention_bias,
230
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
231
+ upcast_attention=upcast_attention,
232
+ qk_norm="layer_norm" if qk_norm else None,
233
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
234
+ )
 
 
 
 
 
 
235
 
236
  self.attn_temporal = get_motion_module(
237
  in_channels = dim,
 
249
  if self.use_ada_layer_norm
250
  else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
251
  )
252
+ self.attn2 = Attention(
253
+ query_dim=dim,
254
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
255
+ heads=num_attention_heads,
256
+ dim_head=attention_head_dim,
257
+ dropout=dropout,
258
+ bias=attention_bias,
259
+ upcast_attention=upcast_attention,
260
+ qk_norm="layer_norm" if qk_norm else None,
261
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
262
+ ) # is self-attn if encoder_hidden_states is none
 
 
 
 
 
 
 
 
 
 
 
263
  else:
264
  self.norm2 = None
265
  self.attn2 = None
 
338
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

         norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
         attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
         if self.use_ada_layer_norm_zero:
             attn_output = gate_msa.unsqueeze(1) * attn_output
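The two `rearrange` calls around `attn1` fold the frame axis into the batch so self-attention runs per frame over the spatial tokens, then unfold it again. A small sketch of that reshaping round trip with dummy sizes:

```python
import torch
from einops import rearrange

b, f, d, c = 2, 4, 16, 8                     # batch, frames, spatial tokens, channels
x = torch.randn(b, f * d, c)                 # the "b (f d) c" layout used by the block

per_frame = rearrange(x, "b (f d) c -> (b f) d c", f=f)    # (8, 16, 8): one frame per batch entry
back = rearrange(per_frame, "(b f) d c -> b (f d) c", f=f)
assert torch.equal(x, back)                  # the round trip is lossless
```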
 
380
         if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)

+        if norm_hidden_states.dtype != encoder_hidden_states.dtype or norm_hidden_states.dtype != encoder_attention_mask.dtype:
+            norm_hidden_states = norm_hidden_states.to(encoder_hidden_states.dtype)
+
         attn_output = self.attn2(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states,
 
485
  double_self_attention: bool = False,
486
  upcast_attention: bool = False,
487
  norm_elementwise_affine: bool = True,
488
+ norm_type: str = "layer_norm",
489
  norm_eps: float = 1e-5,
490
  final_dropout: bool = False,
491
  attention_type: str = "default",
 
527
  else:
528
  self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
529
 
530
+ self.attn1 = Attention(
531
+ query_dim=dim,
532
+ heads=num_attention_heads,
533
+ dim_head=attention_head_dim,
534
+ dropout=dropout,
535
+ bias=attention_bias,
536
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
537
+ upcast_attention=upcast_attention,
538
+ qk_norm="layer_norm" if qk_norm else None,
539
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
540
+ )
 
 
 
 
 
 
 
 
 
 
 
541
 
542
  # 2. Cross-Attn
543
  if cross_attention_dim is not None or double_self_attention:
 
549
  if self.use_ada_layer_norm
550
  else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
551
  )
552
+ self.attn2 = Attention(
553
+ query_dim=dim,
554
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
555
+ heads=num_attention_heads,
556
+ dim_head=attention_head_dim,
557
+ dropout=dropout,
558
+ bias=attention_bias,
559
+ upcast_attention=upcast_attention,
560
+ qk_norm="layer_norm" if qk_norm else None,
561
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
562
+ ) # is self-attn if encoder_hidden_states is none
 
 
 
 
 
 
 
 
 
 
 
563
  else:
564
  self.norm2 = None
565
  self.attn2 = None
 
720
  hidden_states = hidden_states.squeeze(1)
721
 
722
         return hidden_states
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out, norm_elementwise_affine):
+        super().__init__()
+        self.norm = FP32LayerNorm(dim_in, dim_in, norm_elementwise_affine)
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(self.norm(x)).chunk(2, dim=-1)
+        return x * F.gelu(gate)
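The `GEGLU` defined above differs from the stock diffusers `GEGLU` in that it layer-norms its input before projecting to `2 * dim_out`, then gates one half with the GELU of the other. A tiny sketch of just the gating arithmetic:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 16, 8)                # (batch, tokens, dim_in)
proj = nn.Linear(8, 2 * 4)               # dim_in=8 -> 2 * dim_out with dim_out=4
value, gate = proj(x).chunk(2, dim=-1)   # split the projection into value and gate halves
out = value * F.gelu(gate)               # gated output, shape (2, 16, 4)
print(out.shape)
```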
 
734
  @maybe_allow_in_graph
735
+class HunyuanDiTBlock(nn.Module):
     r"""
+    Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows a skip
+    connection and QK normalization.

     Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in the feed-forward block.
         norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
             Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
         final_dropout (`bool`, *optional*, defaults to False):
             Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*):
+            The size of the hidden layer in the feed-forward block. Defaults to `None`.
+        ff_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the feed-forward block.
+        skip (`bool`, *optional*, defaults to `False`):
+            Whether to use a skip connection. Defaults to `False` for down-blocks and mid-blocks.
+        qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use normalization in the QK calculation. Defaults to `True`.
     """
766
 
767
  def __init__(
768
  self,
769
  dim: int,
770
  num_attention_heads: int,
771
+ cross_attention_dim: int = 1024,
772
  dropout=0.0,
 
773
  activation_fn: str = "geglu",
 
 
 
 
 
774
  norm_elementwise_affine: bool = True,
775
+ norm_eps: float = 1e-6,
 
776
  final_dropout: bool = False,
777
+ ff_inner_dim: Optional[int] = None,
778
+ ff_bias: bool = True,
779
+ skip: bool = False,
780
+ qk_norm: bool = True,
781
+ time_position_encoding: bool = False,
782
+ after_norm: bool = False,
783
+ is_local_attention: bool = False,
784
+ local_attention_frames: int = 2,
785
+ enable_inpaint: bool = False,
786
+ kvcompression = False,
787
  ):
788
  super().__init__()
 
 
 
 
 
 
 
789
 
790
  # Define 3 blocks. Each block has its own normalization layer.
791
+ # NOTE: when new version comes, check norm2 and norm 3
792
  # 1. Self-Attn
793
+ self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
794
+ self.t_embed = PositionalEncoding(dim, dropout=0., max_len=512) \
795
+ if time_position_encoding else nn.Identity()
 
 
 
796
 
797
+ self.is_local_attention = is_local_attention
798
+ self.local_attention_frames = local_attention_frames
799
  self.kvcompression = kvcompression
800
  if kvcompression:
801
+ self.attn1 = LazyKVCompressionAttention(
802
  query_dim=dim,
803
+ cross_attention_dim=None,
804
+ dim_head=dim // num_attention_heads,
805
  heads=num_attention_heads,
806
+ qk_norm="layer_norm" if qk_norm else None,
807
+ eps=1e-6,
808
+ bias=True,
809
+ processor=LazyKVCompressionProcessor2_0(),
 
810
  )
811
  else:
812
+ self.attn1 = Attention(
813
+ query_dim=dim,
814
+ cross_attention_dim=None,
815
+ dim_head=dim // num_attention_heads,
816
+ heads=num_attention_heads,
817
+ qk_norm="layer_norm" if qk_norm else None,
818
+ eps=1e-6,
819
+ bias=True,
820
+ processor=HunyuanAttnProcessor2_0(),
821
+ )
 
 
 
 
 
 
 
 
 
 
 
 
822
 
823
  # 2. Cross-Attn
824
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
825
+
826
+ if self.is_local_attention:
827
+ from mamba_ssm import Mamba2
828
+ self.mamba_norm_in = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
829
+ self.in_linear = nn.Linear(dim, 1536)
830
+ self.mamba_norm_1 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
831
+ self.mamba_norm_2 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
832
+
833
+ self.mamba_block_1 = Mamba2(
834
+ d_model=1536,
835
+ d_state=64,
836
+ d_conv=4,
837
+ expand=2,
838
  )
839
+ self.mamba_block_2 = Mamba2(
840
+ d_model=1536,
841
+ d_state=64,
842
+ d_conv=4,
843
+ expand=2,
844
+ )
845
+ self.mamba_norm_after_mamba_block = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
846
+
847
+ self.out_linear = nn.Linear(1536, dim)
848
+ self.out_linear = zero_module(self.out_linear)
849
+ self.mamba_norm_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
850
+
851
+ self.attn2 = Attention(
852
+ query_dim=dim,
853
+ cross_attention_dim=cross_attention_dim,
854
+ dim_head=dim // num_attention_heads,
855
+ heads=num_attention_heads,
856
+ qk_norm="layer_norm" if qk_norm else None,
857
+ eps=1e-6,
858
+ bias=True,
859
+ processor=HunyuanAttnProcessor2_0(),
860
+ )
 
 
 
861
 
862
+ if enable_inpaint:
863
+ self.norm_clip = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
864
+ self.attn_clip = Attention(
865
+ query_dim=dim,
866
+ cross_attention_dim=cross_attention_dim,
867
+ dim_head=dim // num_attention_heads,
868
+ heads=num_attention_heads,
869
+ qk_norm="layer_norm" if qk_norm else None,
870
+ eps=1e-6,
871
+ bias=True,
872
+ processor=HunyuanAttnProcessor2_0(),
873
+ )
874
+ self.gate_clip = GEGLU(dim, dim, norm_elementwise_affine)
875
+ self.norm_clip_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
876
+ else:
877
+ self.attn_clip = None
878
+ self.norm_clip = None
879
+ self.gate_clip = None
880
+ self.norm_clip_out = None
881
+
882
  # 3. Feed-forward
883
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
884
+
885
+ self.ff = FeedForward(
886
+ dim,
887
+ dropout=dropout, ### 0.0
888
+ activation_fn=activation_fn, ### approx GeLU
889
+ final_dropout=final_dropout, ### 0.0
890
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
891
+ bias=ff_bias,
892
+ )
893
 
894
+ # 4. Skip Connection
895
+ if skip:
896
+ self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
897
+ self.skip_linear = nn.Linear(2 * dim, dim)
898
+ else:
899
+ self.skip_linear = None
900
 
901
  if after_norm:
902
  self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
903
  else:
904
  self.norm4 = None
905
 
 
 
 
 
 
 
 
 
906
  # let chunk size default to None
907
  self._chunk_size = None
908
  self._chunk_dim = 0
909
 
910
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
911
  # Sets chunk feed-forward
912
  self._chunk_size = chunk_size
913
  self._chunk_dim = dim
914
 
915
  def forward(
916
  self,
917
+ hidden_states: torch.Tensor,
918
+ encoder_hidden_states: Optional[torch.Tensor] = None,
919
+ temb: Optional[torch.Tensor] = None,
920
+ image_rotary_emb=None,
921
+ skip=None,
922
+ num_frames: int = 1,
 
 
923
  height: int = 32,
924
  width: int = 32,
925
+ clip_encoder_hidden_states: Optional[torch.Tensor] = None,
926
+ disable_image_rotary_emb_in_attn1=False,
927
+ ) -> torch.Tensor:
928
  # Notice that normalization is always applied before the real computation in the following blocks.
929
+ # 0. Long Skip Connection
930
+ if self.skip_linear is not None:
931
+ cat = torch.cat([hidden_states, skip], dim=-1)
932
+ cat = self.skip_norm(cat)
933
+ hidden_states = self.skip_linear(cat)
934
+
935
+ if image_rotary_emb is not None:
936
+ image_rotary_emb = (torch.cat([image_rotary_emb[0] for i in range(num_frames)], dim=0), torch.cat([image_rotary_emb[1] for i in range(num_frames)], dim=0))
937
+
938
+ if num_frames != 1:
939
+ # add time embedding
940
+ hidden_states = rearrange(hidden_states, "b (f d) c -> (b d) f c", f=num_frames)
941
+ if self.t_embed is not None:
942
+ hidden_states = self.t_embed(hidden_states)
943
+ hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", d=height * width)
944
+
945
+ # 1. Self-Attention
946
+ norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
947
+ if num_frames > 2 and self.is_local_attention:
948
+ if image_rotary_emb is not None:
949
+ attn1_image_rotary_emb = (image_rotary_emb[0][:int(height * width * 2)], image_rotary_emb[1][:int(height * width * 2)])
950
+ else:
951
+ attn1_image_rotary_emb = image_rotary_emb
952
+ norm_hidden_states_1 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d=height * width)
953
+ norm_hidden_states_1 = rearrange(norm_hidden_states_1, "b (f p) d c -> (b f) (p d) c", p = 2)
954
+
 
 
 
 
 
 
955
  attn_output = self.attn1(
956
+ norm_hidden_states_1,
957
+ image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
 
 
 
 
 
958
  )
959
+ attn_output = rearrange(attn_output, "(b f) (p d) c -> b (f p) d c", p = 2, f = num_frames // 2)
960
+
961
+ norm_hidden_states_2 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d = height * width)[:, 1:-1]
962
+ local_attention_frames_num = norm_hidden_states_2.size()[1] // 2
963
+ norm_hidden_states_2 = rearrange(norm_hidden_states_2, "b (f p) d c -> (b f) (p d) c", p = 2)
964
+ attn_output_2 = self.attn1(
965
+ norm_hidden_states_2,
966
+ image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
967
  )
968
+ attn_output_2 = rearrange(attn_output_2, "(b f) (p d) c -> b (f p) d c", p = 2, f = local_attention_frames_num)
969
+ attn_output[:, 1:-1] = (attn_output[:, 1:-1] + attn_output_2) / 2
970
 
971
+ attn_output = rearrange(attn_output, "b f d c -> b (f d) c")
972
+ else:
973
+ if self.kvcompression:
974
+ norm_hidden_states = rearrange(norm_hidden_states, "b (f h w) c -> b c f h w", f = num_frames, h = height, w = width)
975
+ attn_output = self.attn1(
976
+ norm_hidden_states,
977
+ image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
978
+ )
 
 
 
 
 
 
979
  else:
980
+ attn_output = self.attn1(
981
+ norm_hidden_states,
982
+ image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
983
+ )
984
+ hidden_states = hidden_states + attn_output
985
+
986
+ if num_frames > 2 and self.is_local_attention:
987
+ hidden_states_in = self.in_linear(self.mamba_norm_in(hidden_states))
988
+ hidden_states = hidden_states + self.mamba_norm_out(
989
+ self.out_linear(
990
+ self.mamba_norm_after_mamba_block(
991
+ self.mamba_block_1(
992
+ self.mamba_norm_1(hidden_states_in)
993
+ ) +
994
+ self.mamba_block_2(
995
+ self.mamba_norm_2(hidden_states_in.flip(1))
996
+ ).flip(1)
997
+ )
998
+ )
999
+ )
1000
+
1001
+ # 2. Cross-Attention
1002
+ hidden_states = hidden_states + self.attn2(
1003
+ self.norm2(hidden_states),
1004
+ encoder_hidden_states=encoder_hidden_states,
1005
+ image_rotary_emb=image_rotary_emb,
1006
+ )
1007
 
1008
+ if self.attn_clip is not None:
1009
+ hidden_states = hidden_states + self.norm_clip_out(
1010
+ self.gate_clip(
1011
+ self.attn_clip(
1012
+ self.norm_clip(hidden_states),
1013
+ encoder_hidden_states=clip_encoder_hidden_states,
1014
+ image_rotary_emb=image_rotary_emb,
1015
+ )
1016
+ )
1017
  )
 
1018
 
1019
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
1020
+ mlp_inputs = self.norm3(hidden_states)
1021
+ if self.norm4 is not None:
1022
+ hidden_states = hidden_states + self.norm4(self.ff(mlp_inputs))
1023
+ else:
1024
+ hidden_states = hidden_states + self.ff(mlp_inputs)
1025
 
1026
+ return hidden_states
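The local-attention path in the forward pass above splits the frame axis into non-overlapping pairs, attends within each pair, runs the same attention on pairs offset by one frame, and averages the two results for the interior frames; the bidirectional Mamba branch (a forward pass plus a flipped pass through a zero-initialized output projection) then mixes information across the whole sequence. A simplified sketch of just the paired-window bookkeeping, with a mean over the window standing in for the real attention call and hypothetical shapes:

```python
import torch
from einops import rearrange

b, f, d, c = 1, 6, 4, 8                         # frame count must be even for the paired windows
x = torch.randn(b, f, d, c)

def window_mix(tokens):                         # stand-in for self.attn1 on a two-frame window
    return tokens.mean(dim=1, keepdim=True).expand_as(tokens)

# Pass 1: windows (0, 1), (2, 3), (4, 5)
w1 = rearrange(x, "b (f p) d c -> (b f) (p d) c", p=2)
o1 = rearrange(window_mix(w1), "(b f) (p d) c -> b (f p) d c", p=2, f=f // 2)

# Pass 2: the same windowing shifted by one frame, covering frames 1 .. f-2
w2 = rearrange(x[:, 1:-1], "b (f p) d c -> (b f) (p d) c", p=2)
o2 = rearrange(window_mix(w2), "(b f) (p d) c -> b (f p) d c", p=2, f=(f - 2) // 2)

o1[:, 1:-1] = (o1[:, 1:-1] + o2) / 2            # interior frames average both windowings
```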
 
1027
 
1028
+ @maybe_allow_in_graph
1029
+ class EasyAnimateDiTBlock(nn.Module):
1030
+ def __init__(
1031
+ self,
1032
+ dim: int,
1033
+ num_attention_heads: int,
1034
+ attention_head_dim: int,
1035
+ time_embed_dim: int,
1036
+ dropout: float = 0.0,
1037
+ activation_fn: str = "gelu-approximate",
1038
+ norm_elementwise_affine: bool = True,
1039
+ norm_eps: float = 1e-6,
1040
+ final_dropout: bool = True,
1041
+ ff_inner_dim: Optional[int] = None,
1042
+ ff_bias: bool = True,
1043
+ qk_norm: bool = True,
1044
+ after_norm: bool = False,
1045
+ norm_type: str="fp32_layer_norm"
1046
+ ):
1047
+ super().__init__()
1048
 
1049
+ # Attention Part
1050
+ self.norm1 = EasyAnimateLayerNormZero(
1051
+ time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
1052
+ )
 
 
1053
 
1054
+ self.attn1 = Attention(
1055
+ query_dim=dim,
1056
+ dim_head=attention_head_dim,
1057
+ heads=num_attention_heads,
1058
+ qk_norm="layer_norm" if qk_norm else None,
1059
+ eps=1e-6,
1060
+ bias=True,
1061
+ processor=EasyAnimateAttnProcessor2_0(),
1062
+ )
1063
+ self.attn2 = Attention(
1064
+ query_dim=dim,
1065
+ dim_head=attention_head_dim,
1066
+ heads=num_attention_heads,
1067
+ qk_norm="layer_norm" if qk_norm else None,
1068
+ eps=1e-6,
1069
+ bias=True,
1070
+ processor=EasyAnimateAttnProcessor2_0(),
1071
+ )
1072
 
1073
+ # FFN Part
1074
+ self.norm2 = EasyAnimateLayerNormZero(
1075
+ time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
1076
+ )
1077
+ self.ff = FeedForward(
1078
+ dim,
1079
+ dropout=dropout,
1080
+ activation_fn=activation_fn,
1081
+ final_dropout=final_dropout,
1082
+ inner_dim=ff_inner_dim,
1083
+ bias=ff_bias,
1084
+ )
1085
+ self.txt_ff = FeedForward(
1086
+ dim,
1087
+ dropout=dropout,
1088
+ activation_fn=activation_fn,
1089
+ final_dropout=final_dropout,
1090
+ inner_dim=ff_inner_dim,
1091
+ bias=ff_bias,
1092
+ )
1093
+ if after_norm:
1094
+ self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1095
+ else:
1096
+ self.norm3 = None
1097
 
1098
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Norm
+        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+            hidden_states, encoder_hidden_states, temb
+        )

+        # Attn
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            attn2=self.attn2,
+        )
+        hidden_states = hidden_states + gate_msa * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

+        # Norm
+        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+            hidden_states, encoder_hidden_states, temb
+        )
+
+        # FFN
+        if self.norm3 is not None:
+            norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
+            norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+        else:
+            norm_hidden_states = self.ff(norm_hidden_states)
+            norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+        hidden_states = hidden_states + gate_ff * norm_hidden_states
+        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
+        return hidden_states, encoder_hidden_states
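The forward pass above is a dual-stream block: `norm1` and `norm2` produce per-stream gates from the timestep embedding, and both the video tokens and the text tokens receive gated residual updates from the shared attention and their own feed-forward branches. A simplified sketch of that gating pattern with plain tensors (in the real block the gates come from `EasyAnimateLayerNormZero` and the updates from `attn1`/`ff`/`txt_ff`):

```python
import torch

hidden = torch.randn(1, 128, 64)             # video tokens
text = torch.randn(1, 32, 64)                # text tokens
gate, enc_gate = torch.rand(1, 1, 64), torch.rand(1, 1, 64)   # stand-ins for the learned gates

attn_hidden = torch.randn_like(hidden)       # stand-in for the joint attention output (video part)
attn_text = torch.randn_like(text)           # stand-in for the joint attention output (text part)

# Gated residual update, one gate per stream, as in the block above.
hidden = hidden + gate * attn_hidden
text = text + enc_gate * attn_text
```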
easyanimate/models/autoencoder_magvit.py CHANGED
@@ -15,8 +15,14 @@ from typing import Dict, Optional, Tuple, Union
15
 
16
  import torch
17
  import torch.nn as nn
18
- import torch.nn.functional as F
19
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
 
 
 
 
 
 
20
 
21
  try:
22
  from diffusers.loaders import FromOriginalVAEMixin
@@ -32,10 +38,16 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
  from diffusers.models.modeling_utils import ModelMixin
33
  from diffusers.utils.accelerate_utils import apply_forward_hook
34
  from torch import nn
 
35
 
 
 
 
 
36
  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
37
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
38
 
 
39
 
40
  def str_eval(item):
41
  if type(item) == str:
@@ -97,10 +109,19 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
97
  latent_channels: int = 4,
98
  norm_num_groups: int = 32,
99
  scaling_factor: float = 0.1825,
 
100
  slice_compression_vae=False,
 
 
101
  use_tiling=False,
 
 
102
  mini_batch_encoder=9,
103
  mini_batch_decoder=3,
 
 
 
 
104
  ):
105
  super().__init__()
106
  down_block_types = str_eval(down_block_types)
@@ -121,8 +142,12 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
121
  act_fn=act_fn,
122
  num_attention_heads=num_attention_heads,
123
  double_z=True,
 
124
  slice_compression_vae=slice_compression_vae,
 
 
125
  mini_batch_encoder=mini_batch_encoder,
 
126
  )
127
 
128
  self.decoder = omnigen_Mag_Decoder(
@@ -140,20 +165,30 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
140
  norm_num_groups=norm_num_groups,
141
  act_fn=act_fn,
142
  num_attention_heads=num_attention_heads,
 
143
  slice_compression_vae=slice_compression_vae,
 
 
144
  mini_batch_decoder=mini_batch_decoder,
 
145
  )
146
 
147
  self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
148
  self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
149
 
 
150
  self.slice_compression_vae = slice_compression_vae
 
 
151
  self.mini_batch_encoder = mini_batch_encoder
152
  self.mini_batch_decoder = mini_batch_decoder
153
  self.use_slicing = False
154
  self.use_tiling = use_tiling
155
- self.tile_sample_min_size = 384
156
- self.tile_overlap_factor = 0.25
 
 
 
157
  self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
158
  self.scaling_factor = scaling_factor
159
 
@@ -253,8 +288,16 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
253
  The latent representations of the encoded images. If `return_dict` is True, a
254
  [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
255
  """
 
 
 
 
256
  if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
257
- return self.tiled_encode(x, return_dict=return_dict)
 
 
 
 
258
 
259
  if self.use_slicing and x.shape[0] > 1:
260
  encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
@@ -271,8 +314,15 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
271
  return AutoencoderKLOutput(latent_dist=posterior)
272
 
273
  def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
 
 
 
274
  if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
275
  return self.tiled_decode(z, return_dict=return_dict)
 
 
 
276
  z = self.post_quant_conv(z)
277
  dec = self.decoder(z)
278
 
@@ -408,6 +458,34 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
408
  result_rows.append(torch.cat(result_row, dim=4))
409
 
410
  dec = torch.cat(result_rows, dim=3)
 
 
 
 
 
 
411
  if not return_dict:
412
  return (dec,)
413
 
@@ -507,3 +585,441 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
507
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
508
  print(m, u)
509
  return model
 
 
 
 
 
 
15
 
16
  import torch
17
  import torch.nn as nn
 
18
  from diffusers.configuration_utils import ConfigMixin, register_to_config
19
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
20
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
21
+ DiagonalGaussianDistribution)
22
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from diffusers.utils import logging
25
+ from diffusers.utils.accelerate_utils import apply_forward_hook
26
 
27
  try:
28
  from diffusers.loaders import FromOriginalVAEMixin
 
38
  from diffusers.models.modeling_utils import ModelMixin
39
  from diffusers.utils.accelerate_utils import apply_forward_hook
40
  from torch import nn
41
+ from diffusers import AutoencoderKL
42
 
43
+ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d,
44
+ CogVideoXDecoder3D,
45
+ CogVideoXEncoder3D,
46
+ CogVideoXSafeConv3d)
47
  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
48
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
49
 
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
 
52
  def str_eval(item):
53
  if type(item) == str:
 
109
  latent_channels: int = 4,
110
  norm_num_groups: int = 32,
111
  scaling_factor: float = 0.1825,
112
+ slice_mag_vae=True,
113
  slice_compression_vae=False,
114
+ cache_compression_vae=False,
115
+ cache_mag_vae=False,
116
  use_tiling=False,
117
+ use_tiling_encoder=False,
118
+ use_tiling_decoder=False,
119
  mini_batch_encoder=9,
120
  mini_batch_decoder=3,
121
+ upcast_vae=False,
122
+ spatial_group_norm=False,
123
+ tile_sample_min_size=384,
124
+ tile_overlap_factor=0.25,
125
  ):
126
  super().__init__()
127
  down_block_types = str_eval(down_block_types)
 
142
  act_fn=act_fn,
143
  num_attention_heads=num_attention_heads,
144
  double_z=True,
145
+ slice_mag_vae=slice_mag_vae,
146
  slice_compression_vae=slice_compression_vae,
147
+ cache_compression_vae=cache_compression_vae,
148
+ cache_mag_vae=cache_mag_vae,
149
  mini_batch_encoder=mini_batch_encoder,
150
+ spatial_group_norm=spatial_group_norm,
151
  )
152
 
153
  self.decoder = omnigen_Mag_Decoder(
 
165
  norm_num_groups=norm_num_groups,
166
  act_fn=act_fn,
167
  num_attention_heads=num_attention_heads,
168
+ slice_mag_vae=slice_mag_vae,
169
  slice_compression_vae=slice_compression_vae,
170
+ cache_compression_vae=cache_compression_vae,
171
+ cache_mag_vae=cache_mag_vae,
172
  mini_batch_decoder=mini_batch_decoder,
173
+ spatial_group_norm=spatial_group_norm,
174
  )
175
 
176
  self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
177
  self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
178
 
179
+ self.slice_mag_vae = slice_mag_vae
180
  self.slice_compression_vae = slice_compression_vae
181
+ self.cache_compression_vae = cache_compression_vae
182
+ self.cache_mag_vae = cache_mag_vae
183
  self.mini_batch_encoder = mini_batch_encoder
184
  self.mini_batch_decoder = mini_batch_decoder
185
  self.use_slicing = False
186
         self.use_tiling = use_tiling
+        self.use_tiling_encoder = use_tiling_encoder
+        self.use_tiling_decoder = use_tiling_decoder
+        self.upcast_vae = upcast_vae
+        self.tile_sample_min_size = tile_sample_min_size
+        self.tile_overlap_factor = tile_overlap_factor
         self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
         self.scaling_factor = scaling_factor
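For reference, `tile_latent_min_size` is just the sample-space tile size divided by the encoder's total spatial downsampling factor. A quick check with the default `tile_sample_min_size=384` and a four-entry `ch_mult` (the entry count is an assumption for illustration; the real value comes from the VAE config):

```python
tile_sample_min_size = 384
ch_mult = [1, 2, 4, 4]   # hypothetical example config
tile_latent_min_size = int(tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
print(tile_latent_min_size)   # 384 / 8 = 48 latent pixels per tile side
```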
194
 
 
288
  The latent representations of the encoded images. If `return_dict` is True, a
289
  [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
290
  """
291
+ if self.upcast_vae:
292
+ x = x.float()
293
+ self.encoder = self.encoder.float()
294
+ self.quant_conv = self.quant_conv.float()
295
  if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
296
+ x = self.tiled_encode(x, return_dict=return_dict)
297
+ return x
298
+ if self.use_tiling_encoder and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
299
+ x = self.tiled_encode(x, return_dict=return_dict)
300
+ return x
301
 
302
  if self.use_slicing and x.shape[0] > 1:
303
  encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
 
314
  return AutoencoderKLOutput(latent_dist=posterior)
315
 
316
  def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
317
+ if self.upcast_vae:
318
+ z = z.float()
319
+ self.decoder = self.decoder.float()
320
+ self.post_quant_conv = self.post_quant_conv.float()
321
  if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
322
  return self.tiled_decode(z, return_dict=return_dict)
323
+ if self.use_tiling_decoder and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
324
+ return self.tiled_decode(z, return_dict=return_dict)
325
+
326
  z = self.post_quant_conv(z)
327
  dec = self.decoder(z)
328
 
 
458
  result_rows.append(torch.cat(result_row, dim=4))
459
 
460
  dec = torch.cat(result_rows, dim=3)
461
+
462
+        # Handle the lower right corner tile separately
+        lower_right_original = z[
+            :,
+            :,
+            :,
+            -self.tile_latent_min_size:,
+            -self.tile_latent_min_size:
+        ]
+        quantized_lower_right = self.decoder(self.post_quant_conv(lower_right_original))
+
+        # Combine
+        H, W = quantized_lower_right.size(-2), quantized_lower_right.size(-1)
+        x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)
+        y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)
+        weights = torch.min(x_weights, y_weights)
+
+        if len(dec.size()) == 4:
+            weights = weights.unsqueeze(0).unsqueeze(0)
+        elif len(dec.size()) == 5:
+            weights = weights.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+        weights = weights.to(dec.device)
+        quantized_area = dec[:, :, :, -H:, -W:]
+        combined = weights * quantized_lower_right + (1 - weights) * quantized_area
+
+        dec[:, :, :, -H:, -W:] = combined
+
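The added block blends the separately decoded lower-right tile back into the stitched result with a `min(x, y)` ramp, so the weight is 0 along the existing top and left seams and grows to 1 toward the corner. A small sketch of the same blend on dummy 2-D maps (hypothetical sizes):

```python
import torch

H, W = 4, 6
x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)   # grows left -> right
y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)   # grows top -> bottom
weights = torch.min(x_weights, y_weights)                       # 0 on the top/left edges, 1 at the corner

stitched_corner = torch.zeros(H, W)    # what the tiled pass produced for this region
fresh_corner = torch.ones(H, W)        # the corner tile decoded on its own
blended = weights * fresh_corner + (1 - weights) * stitched_corner
print(blended[0, 0].item(), blended[-1, -1].item())   # 0.0 at the seam, 1.0 at the corner
```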
489
  if not return_dict:
490
  return (dec,)
491
 
 
585
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
586
  print(m, u)
587
  return model
588
+
589
+
590
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
591
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
592
+ # All rights reserved.
593
+ #
594
+ # Licensed under the Apache License, Version 2.0 (the "License");
595
+ # you may not use this file except in compliance with the License.
596
+ # You may obtain a copy of the License at
597
+ #
598
+ # http://www.apache.org/licenses/LICENSE-2.0
599
+ #
600
+ # Unless required by applicable law or agreed to in writing, software
601
+ # distributed under the License is distributed on an "AS IS" BASIS,
602
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
603
+ # See the License for the specific language governing permissions and
604
+ # limitations under the License.
605
+
606
+
607
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
608
+ r"""
609
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
610
+ [CogVideoX](https://github.com/THUDM/CogVideo).
611
+
612
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
613
+ for all models (such as downloading or saving).
614
+
615
+ Parameters:
616
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
617
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
618
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
619
+ Tuple of downsample block types.
620
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
621
+ Tuple of upsample block types.
622
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
623
+ Tuple of block output channels.
624
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
625
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
626
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
627
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
628
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
629
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
630
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
631
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
632
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
633
+ force_upcast (`bool`, *optional*, default to `True`):
634
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
635
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
636
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
637
+ """
638
+
639
+ _supports_gradient_checkpointing = True
640
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
641
+
642
+ @register_to_config
643
+ def __init__(
644
+ self,
645
+ in_channels: int = 3,
646
+ out_channels: int = 3,
647
+ down_block_types: Tuple[str] = (
648
+ "CogVideoXDownBlock3D",
649
+ "CogVideoXDownBlock3D",
650
+ "CogVideoXDownBlock3D",
651
+ "CogVideoXDownBlock3D",
652
+ ),
653
+ up_block_types: Tuple[str] = (
654
+ "CogVideoXUpBlock3D",
655
+ "CogVideoXUpBlock3D",
656
+ "CogVideoXUpBlock3D",
657
+ "CogVideoXUpBlock3D",
658
+ ),
659
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
660
+ latent_channels: int = 16,
661
+ layers_per_block: int = 3,
662
+ act_fn: str = "silu",
663
+ norm_eps: float = 1e-6,
664
+ norm_num_groups: int = 32,
665
+ temporal_compression_ratio: float = 4,
666
+ sample_height: int = 480,
667
+ sample_width: int = 720,
668
+ scaling_factor: float = 1.15258426,
669
+ shift_factor: Optional[float] = None,
670
+ latents_mean: Optional[Tuple[float]] = None,
671
+ latents_std: Optional[Tuple[float]] = None,
672
+ force_upcast: float = True,
673
+ use_quant_conv: bool = False,
674
+ use_post_quant_conv: bool = False,
675
+ slice_mag_vae=False,
676
+ slice_compression_vae=False,
677
+ cache_compression_vae=False,
678
+ cache_mag_vae=True,
679
+ use_tiling=False,
680
+ mini_batch_encoder=4,
681
+ mini_batch_decoder=1,
682
+ ):
683
+ super().__init__()
684
+
685
+ self.encoder = CogVideoXEncoder3D(
686
+ in_channels=in_channels,
687
+ out_channels=latent_channels,
688
+ down_block_types=down_block_types,
689
+ block_out_channels=block_out_channels,
690
+ layers_per_block=layers_per_block,
691
+ act_fn=act_fn,
692
+ norm_eps=norm_eps,
693
+ norm_num_groups=norm_num_groups,
694
+ temporal_compression_ratio=temporal_compression_ratio,
695
+ )
696
+ self.decoder = CogVideoXDecoder3D(
697
+ in_channels=latent_channels,
698
+ out_channels=out_channels,
699
+ up_block_types=up_block_types,
700
+ block_out_channels=block_out_channels,
701
+ layers_per_block=layers_per_block,
702
+ act_fn=act_fn,
703
+ norm_eps=norm_eps,
704
+ norm_num_groups=norm_num_groups,
705
+ temporal_compression_ratio=temporal_compression_ratio,
706
+ )
707
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
708
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
709
+
710
+ self.use_slicing = False
711
+ self.use_tiling = use_tiling
712
+
713
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
714
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
715
+ # If you decode X latent frames together, the number of output frames is:
716
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
717
+ #
718
+ # Example with num_latent_frames_batch_size = 2:
719
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
720
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
721
+ # => 6 * 8 = 48 frames
722
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
723
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
724
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
725
+ # => 1 * 9 + 5 * 8 = 49 frames
726
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
727
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
728
+ # number of temporal frames.
729
+ self.num_latent_frames_batch_size = 2
730
+
731
+ # We make the minimum height and width of sample for tiling half that of the generally supported
732
+ self.tile_sample_min_height = sample_height // 2
733
+ self.tile_sample_min_width = sample_width // 2
734
+ self.tile_latent_min_height = int(
735
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
736
+ )
737
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
738
+
739
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
740
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
741
+ # and so the tiling implementation has only been tested on those specific resolutions.
742
+ self.tile_overlap_factor_height = 1 / 6
743
+ self.tile_overlap_factor_width = 1 / 5
744
+
745
+ def _set_gradient_checkpointing(self, module, value=False):
746
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
747
+ module.gradient_checkpointing = value
748
+
749
+ def _clear_fake_context_parallel_cache(self):
750
+ for name, module in self.named_modules():
751
+ if isinstance(module, CogVideoXCausalConv3d):
752
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
753
+ module._clear_fake_context_parallel_cache()
754
+
755
+ def enable_tiling(
756
+ self,
757
+ tile_sample_min_height: Optional[int] = None,
758
+ tile_sample_min_width: Optional[int] = None,
759
+ tile_overlap_factor_height: Optional[float] = None,
760
+ tile_overlap_factor_width: Optional[float] = None,
761
+ ) -> None:
762
+ r"""
763
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
764
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
765
+ processing larger images.
766
+
767
+ Args:
768
+ tile_sample_min_height (`int`, *optional*):
769
+ The minimum height required for a sample to be separated into tiles across the height dimension.
770
+ tile_sample_min_width (`int`, *optional*):
771
+ The minimum width required for a sample to be separated into tiles across the width dimension.
772
+ tile_overlap_factor_height (`float`, *optional*):
773
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
774
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
775
+ value might cause more tiles to be processed leading to slow down of the decoding process.
776
+ tile_overlap_factor_width (`float`, *optional*):
777
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
778
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
779
+ value might cause more tiles to be processed leading to slow down of the decoding process.
780
+ """
781
+ self.use_tiling = True
782
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
783
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
784
+ self.tile_latent_min_height = int(
785
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
786
+ )
787
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
788
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
789
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
790
+
791
+ def disable_tiling(self) -> None:
792
+ r"""
793
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
794
+ decoding in one step.
795
+ """
796
+ self.use_tiling = False
797
+
798
+ def enable_slicing(self) -> None:
799
+ r"""
800
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
801
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
802
+ """
803
+ self.use_slicing = True
804
+
805
+ def disable_slicing(self) -> None:
806
+ r"""
807
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
808
+ decoding in one step.
809
+ """
810
+ self.use_slicing = False
811
+
812
+ @apply_forward_hook
813
+ def encode(
814
+ self, x: torch.Tensor, return_dict: bool = True
815
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
816
+ """
817
+ Encode a batch of images into latents.
818
+
819
+ Args:
820
+ x (`torch.Tensor`): Input batch of images.
821
+ return_dict (`bool`, *optional*, defaults to `True`):
822
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
823
+
824
+ Returns:
825
+ The latent representations of the encoded images. If `return_dict` is True, a
826
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
827
+ """
828
+ batch_size, num_channels, num_frames, height, width = x.shape
829
+ if num_frames == 1:
830
+ h = self.encoder(x)
831
+ if self.quant_conv is not None:
832
+ h = self.quant_conv(h)
833
+ posterior = DiagonalGaussianDistribution(h)
834
+ else:
835
+ frame_batch_size = 4
836
+ h = []
837
+ for i in range(num_frames // frame_batch_size):
838
+ remaining_frames = num_frames % frame_batch_size
839
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
840
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
841
+ z_intermediate = x[:, :, start_frame:end_frame]
842
+ z_intermediate = self.encoder(z_intermediate)
843
+ if self.quant_conv is not None:
844
+ z_intermediate = self.quant_conv(z_intermediate)
845
+ h.append(z_intermediate)
846
+ self._clear_fake_context_parallel_cache()
847
+ h = torch.cat(h, dim=2)
848
+ posterior = DiagonalGaussianDistribution(h)
849
+ self._clear_fake_context_parallel_cache()
850
+ if not return_dict:
851
+ return (posterior,)
852
+ return AutoencoderKLOutput(latent_dist=posterior)
853
+
854
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
855
+ batch_size, num_channels, num_frames, height, width = z.shape
856
+
857
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
858
+ return self.tiled_decode(z, return_dict=return_dict)
859
+
860
+ if num_frames == 1:
861
+ dec = []
862
+ z_intermediate = z
863
+ if self.post_quant_conv is not None:
864
+ z_intermediate = self.post_quant_conv(z_intermediate)
865
+ z_intermediate = self.decoder(z_intermediate)
866
+ dec.append(z_intermediate)
867
+ else:
868
+ frame_batch_size = self.num_latent_frames_batch_size
869
+ dec = []
870
+ for i in range(num_frames // frame_batch_size):
871
+ remaining_frames = num_frames % frame_batch_size
872
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
873
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
874
+ z_intermediate = z[:, :, start_frame:end_frame]
875
+ if self.post_quant_conv is not None:
876
+ z_intermediate = self.post_quant_conv(z_intermediate)
877
+ z_intermediate = self.decoder(z_intermediate)
878
+ dec.append(z_intermediate)
879
+
880
+ self._clear_fake_context_parallel_cache()
881
+ dec = torch.cat(dec, dim=2)
882
+
883
+ if not return_dict:
884
+ return (dec,)
885
+
886
+ return DecoderOutput(sample=dec)
887
+
888
+ @apply_forward_hook
889
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
890
+ """
891
+ Decode a batch of images.
892
+
893
+ Args:
894
+ z (`torch.Tensor`): Input batch of latent vectors.
895
+ return_dict (`bool`, *optional*, defaults to `True`):
896
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
897
+
898
+ Returns:
899
+ [`~models.vae.DecoderOutput`] or `tuple`:
900
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
901
+ returned.
902
+ """
903
+ if self.use_slicing and z.shape[0] > 1:
904
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
905
+ decoded = torch.cat(decoded_slices)
906
+ else:
907
+ decoded = self._decode(z).sample
908
+
909
+ if not return_dict:
910
+ return (decoded,)
911
+ return DecoderOutput(sample=decoded)
912
+
913
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
914
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
915
+ for y in range(blend_extent):
916
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
917
+ y / blend_extent
918
+ )
919
+ return b
920
+
921
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
922
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
923
+ for x in range(blend_extent):
924
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
925
+ x / blend_extent
926
+ )
927
+ return b
928
+
929
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
930
+ r"""
931
+ Decode a batch of images using a tiled decoder.
932
+
933
+ Args:
934
+ z (`torch.Tensor`): Input batch of latent vectors.
935
+ return_dict (`bool`, *optional*, defaults to `True`):
936
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
937
+
938
+ Returns:
939
+ [`~models.vae.DecoderOutput`] or `tuple`:
940
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
941
+ returned.
942
+ """
943
+ # Rough memory assessment:
944
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
945
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
946
+ # - Assume fp16 (2 bytes per value).
947
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
948
+ #
949
+ # Memory assessment when using tiling:
950
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
951
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
952
+
953
+ batch_size, num_channels, num_frames, height, width = z.shape
954
+
955
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
956
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
957
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
958
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
959
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
960
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
961
+ frame_batch_size = self.num_latent_frames_batch_size
962
+
963
+ # Split z into overlapping tiles and decode them separately.
964
+ # The tiles have an overlap to avoid seams between tiles.
965
+ rows = []
966
+ for i in range(0, height, overlap_height):
967
+ row = []
968
+ for j in range(0, width, overlap_width):
969
+ time = []
970
+ for k in range(num_frames // frame_batch_size):
971
+ remaining_frames = num_frames % frame_batch_size
972
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
973
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
974
+ tile = z[
975
+ :,
976
+ :,
977
+ start_frame:end_frame,
978
+ i : i + self.tile_latent_min_height,
979
+ j : j + self.tile_latent_min_width,
980
+ ]
981
+ if self.post_quant_conv is not None:
982
+ tile = self.post_quant_conv(tile)
983
+ tile = self.decoder(tile)
984
+ time.append(tile)
985
+ self._clear_fake_context_parallel_cache()
986
+ row.append(torch.cat(time, dim=2))
987
+ rows.append(row)
988
+
989
+ result_rows = []
990
+ for i, row in enumerate(rows):
991
+ result_row = []
992
+ for j, tile in enumerate(row):
993
+ # blend the above tile and the left tile
994
+ # to the current tile and add the current tile to the result row
995
+ if i > 0:
996
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
997
+ if j > 0:
998
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
999
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1000
+ result_rows.append(torch.cat(result_row, dim=4))
1001
+
1002
+ dec = torch.cat(result_rows, dim=3)
1003
+
1004
+ if not return_dict:
1005
+ return (dec,)
1006
+
1007
+ return DecoderOutput(sample=dec)
1008
+
1009
+ def forward(
1010
+ self,
1011
+ sample: torch.Tensor,
1012
+ sample_posterior: bool = False,
1013
+ return_dict: bool = True,
1014
+ generator: Optional[torch.Generator] = None,
1015
+ ) -> Union[torch.Tensor, torch.Tensor]:
1016
+ x = sample
1017
+ posterior = self.encode(x).latent_dist
1018
+ if sample_posterior:
1019
+ z = posterior.sample(generator=generator)
1020
+ else:
1021
+ z = posterior.mode()
1022
+ dec = self.decode(z)
1023
+ if not return_dict:
1024
+ return (dec,)
1025
+ return dec
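The frame-batched decoding above is the easiest part of this diff to misread, so here is a minimal standalone sketch (not part of the commit) that reproduces the frame-count bookkeeping from the long comment next to `num_latent_frames_batch_size`: each decoded slice of X latent frames yields X + 6 output frames (2 conv cache + 2 time upscale_1 + 4 time upscale_2 - 2 causal conv downscale), and the first slice absorbs the remainder.

```python
def expected_decoded_frames(num_latent_frames: int, batch: int = 2) -> int:
    """Reproduce the frame counting described in the comment above (illustrative only)."""
    per_slice_extra = 2 + 2 + 4 - 2  # conv cache + time upscale_1 + time upscale_2 - causal downscale
    remaining = num_latent_frames % batch
    num_slices = num_latent_frames // batch
    if num_slices == 0:
        return 0
    first_slice = (batch + remaining) + per_slice_extra      # special case: first slice takes the remainder
    other_slices = (num_slices - 1) * (batch + per_slice_extra)
    return first_slice + other_slices

assert expected_decoded_frames(12) == 48   # matches the 12-latent-frame example in the comment
assert expected_decoded_frames(13) == 49   # matches the 13-latent-frame example
```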
easyanimate/models/embeddings.py ADDED
@@ -0,0 +1,107 @@
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding,
8
+ TimestepEmbedding, Timesteps)
9
+ from einops import rearrange
10
+ from torch import nn
11
+
12
+
13
+ class HunyuanDiTAttentionPool(nn.Module):
14
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
15
+ super().__init__()
16
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
17
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
18
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
19
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
20
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
21
+ self.num_heads = num_heads
22
+
23
+ def forward(self, x):
24
+ x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)
25
+ x = x + self.positional_embedding[None, :, :].to(x.dtype)
26
+
27
+ query = self.q_proj(x[:, :1])
28
+ key = self.k_proj(x)
29
+ value = self.v_proj(x)
30
+ batch_size, _, _ = query.size()
31
+
32
+ query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2) # (1, H, N, E/H)
33
+ key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H)
34
+ value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H)
35
+
36
+ x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False)
37
+ x = x.transpose(1, 2).reshape(batch_size, 1, -1)
38
+ x = x.to(query.dtype)
39
+ x = self.c_proj(x)
40
+
41
+ return x.squeeze(1)
42
+
43
+
44
+ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
45
+ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
46
+ super().__init__()
47
+
48
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
49
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
50
+
51
+ self.pooler = HunyuanDiTAttentionPool(
52
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
53
+ )
54
+ # Here we use a default learned embedder layer for future extension.
55
+ self.style_embedder = nn.Embedding(1, embedding_dim)
56
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
57
+ self.extra_embedder = PixArtAlphaTextProjection(
58
+ in_features=extra_in_dim,
59
+ hidden_size=embedding_dim * 4,
60
+ out_features=embedding_dim,
61
+ act_fn="silu_fp32",
62
+ )
63
+
64
+ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
65
+ timesteps_proj = self.time_proj(timestep)
66
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
67
+
68
+ # extra condition1: text
69
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
70
+
71
+ # extra condition2: image meta size embedding
72
+ image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
73
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
74
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
75
+
76
+ # extra condition3: style embedding
77
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
78
+
79
+ # Concatenate all extra vectors
80
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
81
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
82
+
83
+ return conditioning
84
+
85
+
86
+ class TimePositionalEncoding(nn.Module):
87
+ def __init__(
88
+ self,
89
+ d_model,
90
+ dropout = 0.,
91
+ max_len = 24
92
+ ):
93
+ super().__init__()
94
+ self.dropout = nn.Dropout(p=dropout)
95
+ position = torch.arange(max_len).unsqueeze(1)
96
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
97
+ pe = torch.zeros(1, max_len, d_model)
98
+ pe[0, :, 0::2] = torch.sin(position * div_term)
99
+ pe[0, :, 1::2] = torch.cos(position * div_term)
100
+ self.register_buffer('pe', pe)
101
+
102
+ def forward(self, x):
103
+ b, c, f, h, w = x.size()
104
+ x = rearrange(x, "b c f h w -> (b h w) f c")
105
+ x = x + self.pe[:, :x.size(1)]
106
+ x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
107
+ return self.dropout(x)
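As a quick check on the new `TimePositionalEncoding`, the snippet below (a standalone sketch, not part of the commit) rebuilds the same sinusoidal table and applies it the way the forward pass does: spatial positions are folded into the batch, the table is added along the frame axis, and the tensor is unfolded back, so the `(b, c, f, h, w)` shape is preserved.

```python
import math

import torch
from einops import rearrange

# Rebuild the sinusoidal table exactly as TimePositionalEncoding.__init__ does (small dims for illustration).
d_model, max_len = 8, 4
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

# Mirror the forward pass: fold (h, w) into the batch, add along frames, unfold back.
x = torch.randn(2, d_model, 4, 3, 3)                        # (b, c, f, h, w)
b, c, f, h, w = x.shape
y = rearrange(x, "b c f h w -> (b h w) f c") + pe[:, :f]
y = rearrange(y, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
assert y.shape == x.shape
```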
easyanimate/models/norm.py CHANGED
@@ -2,7 +2,8 @@ from typing import Any, Dict, Optional, Tuple
2
 
3
  import torch
4
  import torch.nn.functional as F
5
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 
6
  from torch import nn
7
 
8
 
@@ -12,7 +13,6 @@ def zero_module(module):
12
  p.detach().zero_()
13
  return module
14
 
15
-
16
  class FP32LayerNorm(nn.LayerNorm):
17
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
18
  origin_dtype = inputs.dtype
@@ -95,3 +95,56 @@ class AdaLayerNormSingle(nn.Module):
95
  # No modulation happening here.
96
  embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
97
  return self.linear(self.silu(embedded_timestep)), embedded_timestep
 
2
 
3
  import torch
4
  import torch.nn.functional as F
5
+ from diffusers.models.embeddings import (CombinedTimestepLabelEmbeddings,
6
+ TimestepEmbedding, Timesteps)
7
  from torch import nn
8
 
9
 
 
13
  p.detach().zero_()
14
  return module
15
 
 
16
  class FP32LayerNorm(nn.LayerNorm):
17
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
18
  origin_dtype = inputs.dtype
 
95
  # No modulation happening here.
96
  embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
97
  return self.linear(self.silu(embedded_timestep)), embedded_timestep
98
+
99
+ class AdaLayerNormShift(nn.Module):
100
+ r"""
101
+ Norm layer modified to incorporate timestep embeddings.
102
+
103
+ Parameters:
104
+ embedding_dim (`int`): The size of each embedding vector.
105
+ num_embeddings (`int`): The size of the embeddings dictionary.
106
+ """
107
+
108
+ def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
109
+ super().__init__()
110
+ self.silu = nn.SiLU()
111
+ self.linear = nn.Linear(embedding_dim, embedding_dim)
112
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
113
+
114
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
115
+ shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
116
+ x = self.norm(x) + shift.unsqueeze(dim=1)
117
+ return x
118
+
119
+ class EasyAnimateLayerNormZero(nn.Module):
120
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py
121
+ # Add fp32 layer norm
122
+ def __init__(
123
+ self,
124
+ conditioning_dim: int,
125
+ embedding_dim: int,
126
+ elementwise_affine: bool = True,
127
+ eps: float = 1e-5,
128
+ bias: bool = True,
129
+ norm_type: str = "fp32_layer_norm",
130
+ ) -> None:
131
+ super().__init__()
132
+
133
+ self.silu = nn.SiLU()
134
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
135
+ if norm_type == "layer_norm":
136
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
137
+ elif norm_type == "fp32_layer_norm":
138
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
139
+ else:
140
+ raise ValueError(
141
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
142
+ )
143
+
144
+ def forward(
145
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
146
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
147
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
148
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
149
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
150
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
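`EasyAnimateLayerNormZero` follows the adaLN-zero pattern: a single SiLU + Linear on the conditioning vector produces six `(B, D)` chunks that shift/scale the normalized video tokens and text tokens, plus two gates that the caller applies before the residual add. A minimal standalone sketch with illustrative dimensions (not the committed class itself):

```python
import torch
import torch.nn.functional as F
from torch import nn

B, N_vid, N_txt, D, C = 2, 16, 8, 32, 64           # illustrative sizes
temb = torch.randn(B, C)                            # conditioning (timestep) embedding
hidden = torch.randn(B, N_vid, D)                   # video tokens
text = torch.randn(B, N_txt, D)                     # text tokens

proj = nn.Linear(C, 6 * D)                          # one projection yields all six chunks
norm = nn.LayerNorm(D, eps=1e-5)

shift, scale, gate, t_shift, t_scale, t_gate = proj(F.silu(temb)).chunk(6, dim=1)
hidden = norm(hidden) * (1 + scale)[:, None, :] + shift[:, None, :]
text = norm(text) * (1 + t_scale)[:, None, :] + t_shift[:, None, :]
# gate[:, None, :] and t_gate[:, None, :] are returned so the transformer block can
# scale the attention/FFN output before adding it back to the residual stream.
```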
easyanimate/models/patch.py CHANGED
@@ -153,15 +153,6 @@ class TemporalUpsampler3D(Upsampler):
153
  x = torch.cat([first_frame, x], dim=2)
154
  return x
155
 
156
- def cast_tuple(t, length = 1):
157
- return t if isinstance(t, tuple) else ((t,) * length)
158
-
159
- def divisible_by(num, den):
160
- return (num % den) == 0
161
-
162
- def is_odd(n):
163
- return not divisible_by(n, 2)
164
-
165
  class CausalConv3d(nn.Conv3d):
166
  def __init__(
167
  self,
 
153
  x = torch.cat([first_frame, x], dim=2)
154
  return x
155
 
156
  class CausalConv3d(nn.Conv3d):
157
  def __init__(
158
  self,
easyanimate/models/processor.py ADDED
@@ -0,0 +1,312 @@
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ class HunyuanAttnProcessor2_0:
11
+ r"""
12
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
13
+ used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
14
+ """
15
+
16
+ def __init__(self):
17
+ if not hasattr(F, "scaled_dot_product_attention"):
18
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
19
+
20
+ def __call__(
21
+ self,
22
+ attn: Attention,
23
+ hidden_states: torch.Tensor,
24
+ encoder_hidden_states: Optional[torch.Tensor] = None,
25
+ attention_mask: Optional[torch.Tensor] = None,
26
+ temb: Optional[torch.Tensor] = None,
27
+ image_rotary_emb: Optional[torch.Tensor] = None,
28
+ ) -> torch.Tensor:
29
+ residual = hidden_states
30
+ if attn.spatial_norm is not None:
31
+ hidden_states = attn.spatial_norm(hidden_states, temb)
32
+
33
+ input_ndim = hidden_states.ndim
34
+
35
+ if input_ndim == 4:
36
+ batch_size, channel, height, width = hidden_states.shape
37
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
38
+
39
+ batch_size, sequence_length, _ = (
40
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
41
+ )
42
+
43
+ if attention_mask is not None:
44
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
45
+ # scaled_dot_product_attention expects attention_mask shape to be
46
+ # (batch, heads, source_length, target_length)
47
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
48
+
49
+ if attn.group_norm is not None:
50
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
51
+
52
+ query = attn.to_q(hidden_states)
53
+
54
+ if encoder_hidden_states is None:
55
+ encoder_hidden_states = hidden_states
56
+ elif attn.norm_cross:
57
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
58
+
59
+ key = attn.to_k(encoder_hidden_states)
60
+ value = attn.to_v(encoder_hidden_states)
61
+
62
+ inner_dim = key.shape[-1]
63
+ head_dim = inner_dim // attn.heads
64
+
65
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
66
+
67
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
68
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
69
+
70
+ if attn.norm_q is not None:
71
+ query = attn.norm_q(query)
72
+ if attn.norm_k is not None:
73
+ key = attn.norm_k(key)
74
+
75
+ # Apply RoPE if needed
76
+ if image_rotary_emb is not None:
77
+ query = apply_rotary_emb(query, image_rotary_emb)
78
+ if not attn.is_cross_attention:
79
+ key = apply_rotary_emb(key, image_rotary_emb)
80
+
81
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
82
+ # TODO: add support for attn.scale when we move to Torch 2.1
83
+ hidden_states = F.scaled_dot_product_attention(
84
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
85
+ )
86
+
87
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
88
+ hidden_states = hidden_states.to(query.dtype)
89
+
90
+ # linear proj
91
+ hidden_states = attn.to_out[0](hidden_states)
92
+ # dropout
93
+ hidden_states = attn.to_out[1](hidden_states)
94
+
95
+ if input_ndim == 4:
96
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
97
+
98
+ if attn.residual_connection:
99
+ hidden_states = hidden_states + residual
100
+
101
+ hidden_states = hidden_states / attn.rescale_output_factor
102
+
103
+ return hidden_states
104
+
105
+ class LazyKVCompressionProcessor2_0:
106
+ r"""
107
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
108
+ used in the KVCompression model. It applies a s normalization layer and rotary embedding on query and key vector.
109
+ """
110
+
111
+ def __init__(self):
112
+ if not hasattr(F, "scaled_dot_product_attention"):
113
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
114
+
115
+ def __call__(
116
+ self,
117
+ attn: Attention,
118
+ hidden_states: torch.Tensor,
119
+ encoder_hidden_states: Optional[torch.Tensor] = None,
120
+ attention_mask: Optional[torch.Tensor] = None,
121
+ temb: Optional[torch.Tensor] = None,
122
+ image_rotary_emb: Optional[torch.Tensor] = None,
123
+ ) -> torch.Tensor:
124
+ residual = hidden_states
125
+ if attn.spatial_norm is not None:
126
+ hidden_states = attn.spatial_norm(hidden_states, temb)
127
+
128
+ input_ndim = hidden_states.ndim
129
+
130
+ batch_size, channel, num_frames, height, width = hidden_states.shape
131
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=num_frames, h=height, w=width)
132
+
133
+ batch_size, sequence_length, _ = (
134
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
135
+ )
136
+
137
+ if attention_mask is not None:
138
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
139
+ # scaled_dot_product_attention expects attention_mask shape to be
140
+ # (batch, heads, source_length, target_length)
141
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
142
+
143
+ if attn.group_norm is not None:
144
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
145
+
146
+ query = attn.to_q(hidden_states)
147
+
148
+ if encoder_hidden_states is None:
149
+ encoder_hidden_states = hidden_states
150
+ elif attn.norm_cross:
151
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
152
+
153
+ key = attn.to_k(encoder_hidden_states)
154
+ value = attn.to_v(encoder_hidden_states)
155
+
156
+ key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
157
+ key = attn.k_compression(key)
158
+ key_shape = key.size()
159
+ key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
160
+
161
+ value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
162
+ value = attn.v_compression(value)
163
+ value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
164
+
165
+ inner_dim = key.shape[-1]
166
+ head_dim = inner_dim // attn.heads
167
+
168
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
169
+
170
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
171
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
172
+
173
+ if attn.norm_q is not None:
174
+ query = attn.norm_q(query)
175
+ if attn.norm_k is not None:
176
+ key = attn.norm_k(key)
177
+
178
+ # Apply RoPE if needed
179
+ if image_rotary_emb is not None:
180
+ compression_image_rotary_emb = (
181
+ rearrange(image_rotary_emb[0], "(f h w) c -> f c h w", f=num_frames, h=height, w=width),
182
+ rearrange(image_rotary_emb[1], "(f h w) c -> f c h w", f=num_frames, h=height, w=width),
183
+ )
184
+ compression_image_rotary_emb = (
185
+ F.interpolate(compression_image_rotary_emb[0], size=key_shape[-2:], mode='bilinear'),
186
+ F.interpolate(compression_image_rotary_emb[1], size=key_shape[-2:], mode='bilinear')
187
+ )
188
+ compression_image_rotary_emb = (
189
+ rearrange(compression_image_rotary_emb[0], "f c h w -> (f h w) c"),
190
+ rearrange(compression_image_rotary_emb[1], "f c h w -> (f h w) c"),
191
+ )
192
+
193
+ query = apply_rotary_emb(query, image_rotary_emb)
194
+ if not attn.is_cross_attention:
195
+ key = apply_rotary_emb(key, compression_image_rotary_emb)
196
+
197
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
198
+ # TODO: add support for attn.scale when we move to Torch 2.1
199
+ hidden_states = F.scaled_dot_product_attention(
200
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
201
+ )
202
+
203
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
204
+ hidden_states = hidden_states.to(query.dtype)
205
+
206
+ # linear proj
207
+ hidden_states = attn.to_out[0](hidden_states)
208
+ # dropout
209
+ hidden_states = attn.to_out[1](hidden_states)
210
+
211
+ if attn.residual_connection:
212
+ hidden_states = hidden_states + residual
213
+
214
+ hidden_states = hidden_states / attn.rescale_output_factor
215
+
216
+ return hidden_states
217
+
218
+ class EasyAnimateAttnProcessor2_0:
219
+ def __init__(self):
220
+ pass
221
+
222
+ def __call__(
223
+ self,
224
+ attn: Attention,
225
+ hidden_states: torch.Tensor,
226
+ encoder_hidden_states: torch.Tensor,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ image_rotary_emb: Optional[torch.Tensor] = None,
229
+ attn2: Attention = None,
230
+ ) -> torch.Tensor:
231
+ text_seq_length = encoder_hidden_states.size(1)
232
+
233
+ batch_size, sequence_length, _ = (
234
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
235
+ )
236
+
237
+ if attention_mask is not None:
238
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
239
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
240
+
241
+ if attn2 is None:
242
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
243
+
244
+ query = attn.to_q(hidden_states)
245
+ key = attn.to_k(hidden_states)
246
+ value = attn.to_v(hidden_states)
247
+
248
+ inner_dim = key.shape[-1]
249
+ head_dim = inner_dim // attn.heads
250
+
251
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
253
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
254
+
255
+ if attn.norm_q is not None:
256
+ query = attn.norm_q(query)
257
+ if attn.norm_k is not None:
258
+ key = attn.norm_k(key)
259
+
260
+ if attn2 is not None:
261
+ query_txt = attn2.to_q(encoder_hidden_states)
262
+ key_txt = attn2.to_k(encoder_hidden_states)
263
+ value_txt = attn2.to_v(encoder_hidden_states)
264
+
265
+ inner_dim = key_txt.shape[-1]
266
+ head_dim = inner_dim // attn.heads
267
+
268
+ query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
269
+ key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
270
+ value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
271
+
272
+ if attn2.norm_q is not None:
273
+ query_txt = attn2.norm_q(query_txt)
274
+ if attn2.norm_k is not None:
275
+ key_txt = attn2.norm_k(key_txt)
276
+
277
+ query = torch.cat([query_txt, query], dim=2)
278
+ key = torch.cat([key_txt, key], dim=2)
279
+ value = torch.cat([value_txt, value], dim=2)
280
+
281
+ # Apply RoPE if needed
282
+ if image_rotary_emb is not None:
283
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
284
+ if not attn.is_cross_attention:
285
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
286
+
287
+ hidden_states = F.scaled_dot_product_attention(
288
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
289
+ )
290
+
291
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
292
+
293
+ if attn2 is None:
294
+ # linear proj
295
+ hidden_states = attn.to_out[0](hidden_states)
296
+ # dropout
297
+ hidden_states = attn.to_out[1](hidden_states)
298
+
299
+ encoder_hidden_states, hidden_states = hidden_states.split(
300
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
301
+ )
302
+ else:
303
+ encoder_hidden_states, hidden_states = hidden_states.split(
304
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
305
+ )
306
+ # linear proj
307
+ hidden_states = attn.to_out[0](hidden_states)
308
+ encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
309
+ # dropout
310
+ hidden_states = attn.to_out[1](hidden_states)
311
+ encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
312
+ return hidden_states, encoder_hidden_states
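`EasyAnimateAttnProcessor2_0` runs one attention pass over the concatenation of text and video tokens, applies RoPE only to the video positions, and then splits the result back so each stream can use its own output projection. The following standalone sketch (illustrative shapes, RoPE omitted for brevity) shows that concatenate-attend-split pattern:

```python
import torch
import torch.nn.functional as F

B, heads, head_dim = 2, 4, 16
text_len, video_len = 8, 64

def to_heads(x: torch.Tensor) -> torch.Tensor:
    # (B, N, heads * head_dim) -> (B, heads, N, head_dim)
    return x.view(B, -1, heads, head_dim).transpose(1, 2)

txt = torch.randn(B, text_len, heads * head_dim)
vid = torch.randn(B, video_len, heads * head_dim)

# Text tokens are prepended to the video tokens along the sequence axis.
q = torch.cat([to_heads(txt), to_heads(vid)], dim=2)
k, v = q.clone(), q.clone()                         # self-attention over the joint sequence

out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, -1, heads * head_dim)

# Split back into the two streams, as the processor does before to_out / attn2.to_out.
enc_out, vid_out = out.split([text_len, out.size(1) - text_len], dim=1)
assert enc_out.shape == (B, text_len, heads * head_dim)
assert vid_out.shape == (B, video_len, heads * head_dim)
```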
easyanimate/models/resampler.py ADDED
@@ -0,0 +1,146 @@
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from torch.nn.init import normal_
13
+
14
+
15
+ def get_abs_pos(abs_pos, tgt_size):
16
+ # abs_pos: L, C
17
+ # tgt_size: M
18
+ # return: M, C
19
+ src_size = int(math.sqrt(abs_pos.size(0)))
20
+ tgt_size = int(math.sqrt(tgt_size))
21
+ dtype = abs_pos.dtype
22
+
23
+ if src_size != tgt_size:
24
+ return F.interpolate(
25
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
26
+ size=(tgt_size, tgt_size),
27
+ mode="bicubic",
28
+ align_corners=False,
29
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
30
+ else:
31
+ return abs_pos
32
+
33
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
34
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
35
+ """
36
+ grid_size: int of the grid height and width
37
+ return:
38
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
39
+ """
40
+ grid_h = np.arange(grid_size, dtype=np.float32)
41
+ grid_w = np.arange(grid_size, dtype=np.float32)
42
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
43
+ grid = np.stack(grid, axis=0)
44
+
45
+ grid = grid.reshape([2, 1, grid_size, grid_size])
46
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
47
+ if cls_token:
48
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
49
+ return pos_embed
50
+
51
+
52
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
53
+ assert embed_dim % 2 == 0
54
+
55
+ # use half of dimensions to encode grid_h
56
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
57
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
58
+
59
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
60
+ return emb
61
+
62
+
63
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
64
+ """
65
+ embed_dim: output dimension for each position
66
+ pos: a list of positions to be encoded: size (M,)
67
+ out: (M, D)
68
+ """
69
+ assert embed_dim % 2 == 0
70
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
71
+ omega /= embed_dim / 2.
72
+ omega = 1. / 10000**omega # (D/2,)
73
+
74
+ pos = pos.reshape(-1) # (M,)
75
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
76
+
77
+ emb_sin = np.sin(out) # (M, D/2)
78
+ emb_cos = np.cos(out) # (M, D/2)
79
+
80
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
81
+ return emb
82
+
83
+ class Resampler(nn.Module):
84
+ """
85
+ A 2D perceiver-resampler network with a single cross-attention layer over
86
+ (grid_size**2) learnable queries and a 2D sincos pos_emb.
87
+ Outputs:
88
+ A tensor with the shape of (grid_size**2, embed_dim)
89
+ """
90
+ def __init__(
91
+ self,
92
+ grid_size,
93
+ embed_dim,
94
+ num_heads,
95
+ kv_dim=None,
96
+ norm_layer=nn.LayerNorm
97
+ ):
98
+ super().__init__()
99
+ self.num_queries = grid_size ** 2
100
+ self.embed_dim = embed_dim
101
+ self.num_heads = num_heads
102
+
103
+ self.pos_embed = nn.Parameter(
104
+ torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
105
+ ).requires_grad_(False)
106
+
107
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
108
+ normal_(self.query, std=.02)
109
+
110
+ if kv_dim is not None and kv_dim != embed_dim:
111
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
112
+ else:
113
+ self.kv_proj = nn.Identity()
114
+
115
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
116
+ self.ln_q = norm_layer(embed_dim)
117
+ self.ln_kv = norm_layer(embed_dim)
118
+
119
+ self.apply(self._init_weights)
120
+
121
+ def _init_weights(self, m):
122
+ if isinstance(m, nn.Linear):
123
+ normal_(m.weight, std=.02)
124
+ if isinstance(m, nn.Linear) and m.bias is not None:
125
+ nn.init.constant_(m.bias, 0)
126
+ elif isinstance(m, nn.LayerNorm):
127
+ nn.init.constant_(m.bias, 0)
128
+ nn.init.constant_(m.weight, 1.0)
129
+
130
+ def forward(self, x, key_padding_mask=None):
131
+ pos_embed = get_abs_pos(self.pos_embed, x.size(1))
132
+
133
+ x = self.kv_proj(x)
134
+ x = self.ln_kv(x).permute(1, 0, 2)
135
+
136
+ N = x.shape[1]
137
+ q = self.ln_q(self.query)
138
+ out = self.attn(
139
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
140
+ x + pos_embed.unsqueeze(1),
141
+ x,
142
+ key_padding_mask=key_padding_mask)[0]
143
+ return out.permute(1, 0, 2)
144
+
145
+ def _repeat(self, query, N: int):
146
+ return query.unsqueeze(1).repeat(1, N, 1)
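The `Resampler` maps a variable-length feature sequence (e.g. CLIP image tokens) onto a fixed set of `grid_size**2` learned queries via one cross-attention layer with 2D sincos positional embeddings. A hypothetical usage sketch follows; the import path comes from the file header of this diff, while the constructor values and the 256-token input are illustrative and not taken from the commit:

```python
import torch

from easyanimate.models.resampler import Resampler  # module added in this commit

# grid_size=8 -> 64 learned queries; kv_dim is the width of the incoming features (illustrative).
resampler = Resampler(grid_size=8, embed_dim=1024, num_heads=16, kv_dim=768)

clip_tokens = torch.randn(2, 256, 768)   # 16x16 patch tokens, so get_abs_pos can interpolate cleanly
out = resampler(clip_tokens)
print(out.shape)                         # torch.Size([2, 64, 1024])
```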
easyanimate/models/transformer2d.py CHANGED
@@ -37,10 +37,6 @@ except:
37
  from diffusers.models.embeddings import \
38
  CaptionProjection as PixArtAlphaTextProjection
39
 
40
- from .attention import (KVCompressionTransformerBlock,
41
- SelfAttentionTemporalTransformerBlock,
42
- TemporalTransformerBlock)
43
-
44
 
45
  @dataclass
46
  class Transformer2DModelOutput(BaseOutput):
@@ -196,58 +192,29 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
196
  interpolation_scale=interpolation_scale,
197
  )
198
 
199
- basic_block = {
200
- "basic": BasicTransformerBlock,
201
- "kvcompression": KVCompressionTransformerBlock,
202
- }[self.basic_block_type]
203
- if self.basic_block_type == "kvcompression":
204
- self.transformer_blocks = nn.ModuleList(
205
- [
206
- basic_block(
207
- inner_dim,
208
- num_attention_heads,
209
- attention_head_dim,
210
- dropout=dropout,
211
- cross_attention_dim=cross_attention_dim,
212
- activation_fn=activation_fn,
213
- num_embeds_ada_norm=num_embeds_ada_norm,
214
- attention_bias=attention_bias,
215
- only_cross_attention=only_cross_attention,
216
- double_self_attention=double_self_attention,
217
- upcast_attention=upcast_attention,
218
- norm_type=norm_type,
219
- norm_elementwise_affine=norm_elementwise_affine,
220
- norm_eps=norm_eps,
221
- attention_type=attention_type,
222
- kvcompression=False if d < 14 else True,
223
- )
224
- for d in range(num_layers)
225
- ]
226
- )
227
- else:
228
- # 3. Define transformers blocks
229
- self.transformer_blocks = nn.ModuleList(
230
- [
231
- BasicTransformerBlock(
232
- inner_dim,
233
- num_attention_heads,
234
- attention_head_dim,
235
- dropout=dropout,
236
- cross_attention_dim=cross_attention_dim,
237
- activation_fn=activation_fn,
238
- num_embeds_ada_norm=num_embeds_ada_norm,
239
- attention_bias=attention_bias,
240
- only_cross_attention=only_cross_attention,
241
- double_self_attention=double_self_attention,
242
- upcast_attention=upcast_attention,
243
- norm_type=norm_type,
244
- norm_elementwise_affine=norm_elementwise_affine,
245
- norm_eps=norm_eps,
246
- attention_type=attention_type,
247
- )
248
- for d in range(num_layers)
249
- ]
250
- )
251
 
252
  # 4. Define output layers
253
  self.out_channels = in_channels if out_channels is None else out_channels
@@ -413,7 +380,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
413
  if self.training and self.gradient_checkpointing:
414
  args = {
415
  "basic": [],
416
- "kvcompression": [1, height, width],
417
  }[self.basic_block_type]
418
  hidden_states = torch.utils.checkpoint.checkpoint(
419
  block,
@@ -430,7 +396,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
430
  else:
431
  kwargs = {
432
  "basic": {},
433
- "kvcompression": {"num_frames":1, "height":height, "width":width},
434
  }[self.basic_block_type]
435
  hidden_states = block(
436
  hidden_states,
 
37
  from diffusers.models.embeddings import \
38
  CaptionProjection as PixArtAlphaTextProjection
39
 
 
 
 
 
40
 
41
  @dataclass
42
  class Transformer2DModelOutput(BaseOutput):
 
192
  interpolation_scale=interpolation_scale,
193
  )
194
 
195
+ # 3. Define transformers blocks
196
+ self.transformer_blocks = nn.ModuleList(
197
+ [
198
+ BasicTransformerBlock(
199
+ inner_dim,
200
+ num_attention_heads,
201
+ attention_head_dim,
202
+ dropout=dropout,
203
+ cross_attention_dim=cross_attention_dim,
204
+ activation_fn=activation_fn,
205
+ num_embeds_ada_norm=num_embeds_ada_norm,
206
+ attention_bias=attention_bias,
207
+ only_cross_attention=only_cross_attention,
208
+ double_self_attention=double_self_attention,
209
+ upcast_attention=upcast_attention,
210
+ norm_type=norm_type,
211
+ norm_elementwise_affine=norm_elementwise_affine,
212
+ norm_eps=norm_eps,
213
+ attention_type=attention_type,
214
+ )
215
+ for d in range(num_layers)
216
+ ]
217
+ )
 
218
 
219
  # 4. Define output layers
220
  self.out_channels = in_channels if out_channels is None else out_channels
 
380
  if self.training and self.gradient_checkpointing:
381
  args = {
382
  "basic": [],
 
383
  }[self.basic_block_type]
384
  hidden_states = torch.utils.checkpoint.checkpoint(
385
  block,
 
396
  else:
397
  kwargs = {
398
  "basic": {},
 
399
  }[self.basic_block_type]
400
  hidden_states = block(
401
  hidden_states,
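After this commit only the `basic` block type remains in `Transformer2DModel`, so the args/kwargs lookup above degenerates to empty extras, while gradient checkpointing still wraps each block during training. Below is a standalone sketch of that dispatch pattern using a toy block, not the real `BasicTransformerBlock`:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):               # stand-in for BasicTransformerBlock
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

basic_block_type = "basic"
block = ToyBlock(32)
hidden_states = torch.randn(2, 16, 32, requires_grad=True)

# Per-block extra positional args keyed by block type; only "basic" is left after this commit.
args = {"basic": []}[basic_block_type]

if block.training:                       # checkpointing trades recompute for activation memory
    hidden_states = checkpoint(block, hidden_states, *args, use_reentrant=False)
else:
    hidden_states = block(hidden_states, *args)
assert hidden_states.shape == (2, 16, 32)
```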
easyanimate/models/transformer3d.py CHANGED
@@ -11,34 +11,39 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
 
14
  import json
15
  import math
16
  import os
17
  from dataclasses import dataclass
18
- from typing import Any, Dict, Optional, Tuple
19
 
20
  import numpy as np
21
  import torch
22
  import torch.nn.functional as F
23
- import torch.nn.init as init
24
  from diffusers.configuration_utils import ConfigMixin, register_to_config
25
- from diffusers.models.attention import BasicTransformerBlock, FeedForward
26
  from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
27
- TimestepEmbedding, Timesteps)
28
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
 
29
  from diffusers.models.modeling_utils import ModelMixin
30
- from diffusers.models.normalization import AdaLayerNormContinuous
31
  from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
32
  logging)
33
  from diffusers.utils.torch_utils import maybe_allow_in_graph
34
  from einops import rearrange
35
  from torch import nn
36
 
37
- from .attention import (SelfAttentionTemporalTransformerBlock,
38
- TemporalTransformerBlock)
 
 
39
  from .norm import AdaLayerNormSingle
40
- from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D,
41
  TemporalUpsampler3D, UnPatch1D)
 
42
 
43
  try:
44
  from diffusers.models.embeddings import PixArtAlphaTextProjection
@@ -46,12 +51,6 @@ except:
46
  from diffusers.models.embeddings import \
47
  CaptionProjection as PixArtAlphaTextProjection
48
 
49
- def zero_module(module):
50
- # Zero out the parameters of a module and return it.
51
- for p in module.parameters():
52
- p.detach().zero_()
53
- return module
54
-
55
 
56
  class CLIPProjection(nn.Module):
57
  """
@@ -72,28 +71,6 @@ class CLIPProjection(nn.Module):
72
  hidden_states = self.linear_2(hidden_states)
73
  return hidden_states
74
 
75
- class TimePositionalEncoding(nn.Module):
76
- def __init__(
77
- self,
78
- d_model,
79
- dropout = 0.,
80
- max_len = 24
81
- ):
82
- super().__init__()
83
- self.dropout = nn.Dropout(p=dropout)
84
- position = torch.arange(max_len).unsqueeze(1)
85
- div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
86
- pe = torch.zeros(1, max_len, d_model)
87
- pe[0, :, 0::2] = torch.sin(position * div_term)
88
- pe[0, :, 1::2] = torch.cos(position * div_term)
89
- self.register_buffer('pe', pe)
90
-
91
- def forward(self, x):
92
- b, c, f, h, w = x.size()
93
- x = rearrange(x, "b c f h w -> (b h w) f c")
94
- x = x + self.pe[:, :x.size(1)]
95
- x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
96
- return self.dropout(x)
97
 
98
  @dataclass
99
  class Transformer3DModelOutput(BaseOutput):
@@ -189,6 +166,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
189
 
190
  qk_norm = False,
191
  after_norm = False,
 
 
 
 
192
  ):
193
  super().__init__()
194
  self.use_linear_projection = use_linear_projection
@@ -202,9 +183,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
202
  self.casual_3d = casual_3d
203
  self.casual_3d_upsampler_index = casual_3d_upsampler_index
204
 
205
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
206
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
207
-
208
  assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
209
 
210
  self.height = sample_size
@@ -310,34 +288,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
310
  for d in range(num_layers)
311
  ]
312
  )
313
- elif self.basic_block_type == "kvcompression_motionmodule":
314
- self.transformer_blocks = nn.ModuleList(
315
- [
316
- TemporalTransformerBlock(
317
- inner_dim,
318
- num_attention_heads,
319
- attention_head_dim,
320
- dropout=dropout,
321
- cross_attention_dim=cross_attention_dim,
322
- activation_fn=activation_fn,
323
- num_embeds_ada_norm=num_embeds_ada_norm,
324
- attention_bias=attention_bias,
325
- only_cross_attention=only_cross_attention,
326
- double_self_attention=double_self_attention,
327
- upcast_attention=upcast_attention,
328
- norm_type=norm_type,
329
- norm_elementwise_affine=norm_elementwise_affine,
330
- norm_eps=norm_eps,
331
- attention_type=attention_type,
332
- kvcompression=False if d < 14 else True,
333
- motion_module_type=motion_module_type,
334
- motion_module_kwargs=motion_module_kwargs,
335
- qk_norm=qk_norm,
336
- after_norm=after_norm,
337
- )
338
- for d in range(num_layers)
339
- ]
340
- )
341
  elif self.basic_block_type == "selfattentiontemporal":
342
  self.transformer_blocks = nn.ModuleList(
343
  [
@@ -448,6 +398,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
448
  self,
449
  hidden_states: torch.Tensor,
450
  inpaint_latents: torch.Tensor = None,
 
451
  encoder_hidden_states: Optional[torch.Tensor] = None,
452
  clip_encoder_hidden_states: Optional[torch.Tensor] = None,
453
  timestep: Optional[torch.LongTensor] = None,
@@ -524,6 +475,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
524
 
525
  if inpaint_latents is not None:
526
  hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
 
 
527
  # 1. Input
528
  if self.casual_3d:
529
  video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
@@ -596,7 +549,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
596
  "motionmodule": [video_length, height, width],
597
  "global_motionmodule": [video_length, height, width],
598
  "selfattentiontemporal": [],
599
- "kvcompression_motionmodule": [video_length, height, width],
600
  }[self.basic_block_type]
601
  hidden_states = torch.utils.checkpoint.checkpoint(
602
  create_custom_forward(block),
@@ -616,7 +568,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
616
  "motionmodule": {"num_frames":video_length, "height":height, "width":width},
617
  "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
618
  "selfattentiontemporal": {},
619
- "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
620
  }[self.basic_block_type]
621
  hidden_states = block(
622
  hidden_states,
@@ -741,4 +692,745 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
741
  params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
742
  print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
743
744
  return model
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+
15
+ import glob
16
  import json
17
  import math
18
  import os
19
  from dataclasses import dataclass
20
+ from typing import Any, Dict, Optional
21
 
22
  import numpy as np
23
  import torch
24
  import torch.nn.functional as F
 
25
  from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.models.attention import BasicTransformerBlock
27
  from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
28
+ TimestepEmbedding, Timesteps,
29
+ get_2d_sincos_pos_embed)
30
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
  from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous
33
  from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
34
  logging)
35
  from diffusers.utils.torch_utils import maybe_allow_in_graph
36
  from einops import rearrange
37
  from torch import nn
38
 
39
+ from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
40
+ SelfAttentionTemporalTransformerBlock,
41
+ TemporalTransformerBlock, zero_module)
42
+ from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, TimePositionalEncoding
43
  from .norm import AdaLayerNormSingle
44
+ from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
45
  TemporalUpsampler3D, UnPatch1D)
46
+ from .resampler import Resampler
47
 
48
  try:
49
  from diffusers.models.embeddings import PixArtAlphaTextProjection
 
51
  from diffusers.models.embeddings import \
52
  CaptionProjection as PixArtAlphaTextProjection
53
 
 
 
 
 
 
 
54
 
55
  class CLIPProjection(nn.Module):
56
  """
 
71
  hidden_states = self.linear_2(hidden_states)
72
  return hidden_states
73
 
 
74
 
75
  @dataclass
76
  class Transformer3DModelOutput(BaseOutput):
 
166
 
167
  qk_norm = False,
168
  after_norm = False,
169
+ resize_inpaint_mask_directly: bool = False,
170
+ enable_clip_in_inpaint: bool = True,
171
+ enable_text_attention_mask: bool = True,
172
+ add_noise_in_inpaint_model: bool = False,
173
  ):
174
  super().__init__()
175
  self.use_linear_projection = use_linear_projection
 
183
  self.casual_3d = casual_3d
184
  self.casual_3d_upsampler_index = casual_3d_upsampler_index
185
 
 
 
 
186
  assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
187
 
188
  self.height = sample_size
 
288
  for d in range(num_layers)
289
  ]
290
  )
 
 
 
 
291
  elif self.basic_block_type == "selfattentiontemporal":
292
  self.transformer_blocks = nn.ModuleList(
293
  [
 
398
  self,
399
  hidden_states: torch.Tensor,
400
  inpaint_latents: torch.Tensor = None,
401
+ control_latents: torch.Tensor = None,
402
  encoder_hidden_states: Optional[torch.Tensor] = None,
403
  clip_encoder_hidden_states: Optional[torch.Tensor] = None,
404
  timestep: Optional[torch.LongTensor] = None,
 
475
 
476
  if inpaint_latents is not None:
477
  hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
478
+ if control_latents is not None:
479
+ hidden_states = torch.concat([hidden_states, control_latents], 1)
480
  # 1. Input
481
  if self.casual_3d:
482
  video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
 
549
  "motionmodule": [video_length, height, width],
550
  "global_motionmodule": [video_length, height, width],
551
  "selfattentiontemporal": [],
 
552
  }[self.basic_block_type]
553
  hidden_states = torch.utils.checkpoint.checkpoint(
554
  create_custom_forward(block),
 
568
  "motionmodule": {"num_frames":video_length, "height":height, "width":width},
569
  "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
570
  "selfattentiontemporal": {},
 
571
  }[self.basic_block_type]
572
  hidden_states = block(
573
  hidden_states,
 
692
  params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
693
  print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
694
 
695
+ return model
696
+
697
+ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
698
+ """
699
+ HunYuanDiT: Diffusion model with a Transformer backbone.
700
+
701
+ Inherits ModelMixin and ConfigMixin to be compatible with the diffusers StableDiffusionPipeline sampler.
702
+
703
+ Parameters:
704
+ num_attention_heads (`int`, *optional*, defaults to 16):
705
+ The number of heads to use for multi-head attention.
706
+ attention_head_dim (`int`, *optional*, defaults to 88):
707
+ The number of channels in each head.
708
+ in_channels (`int`, *optional*):
709
+ The number of channels in the input and output (specify if the input is **continuous**).
710
+ patch_size (`int`, *optional*):
711
+ The size of the patch to use for the input.
712
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
713
+ Activation function to use in feed-forward.
714
+ sample_size (`int`, *optional*):
715
+ The width of the latent images. This is fixed during training since it is used to learn a number of
716
+ position embeddings.
717
+ dropout (`float`, *optional*, defaults to 0.0):
718
+ The dropout probability to use.
719
+ cross_attention_dim (`int`, *optional*):
720
+ The number of dimensions in the CLIP text embedding.
721
+ hidden_size (`int`, *optional*):
722
+ The size of the hidden layer in the conditioning embedding layers.
723
+ num_layers (`int`, *optional*, defaults to 1):
724
+ The number of layers of Transformer blocks to use.
725
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
726
+ The ratio of the hidden layer size to the input size.
727
+ learn_sigma (`bool`, *optional*, defaults to `True`):
728
+ Whether to predict variance.
729
+ cross_attention_dim_t5 (`int`, *optional*):
730
+ The number of dimensions in the T5 text embedding.
731
+ pooled_projection_dim (`int`, *optional*):
732
+ The size of the pooled projection.
733
+ text_len (`int`, *optional*):
734
+ The length of the clip text embedding.
735
+ text_len_t5 (`int`, *optional*):
736
+ The length of the T5 text embedding.
737
+ """
738
+ _supports_gradient_checkpointing = True
739
+
740
+ @register_to_config
741
+ def __init__(
742
+ self,
743
+ num_attention_heads: int = 16,
744
+ attention_head_dim: int = 88,
745
+ in_channels: Optional[int] = None,
746
+ out_channels: Optional[int] = None,
747
+ patch_size: Optional[int] = None,
748
+
749
+ n_query=16,
750
+ projection_dim=768,
751
+ activation_fn: str = "gelu-approximate",
752
+ sample_size=32,
753
+ hidden_size=1152,
754
+ num_layers: int = 28,
755
+ mlp_ratio: float = 4.0,
756
+ learn_sigma: bool = True,
757
+ cross_attention_dim: int = 1024,
758
+ norm_type: str = "layer_norm",
759
+ cross_attention_dim_t5: int = 2048,
760
+ pooled_projection_dim: int = 1024,
761
+ text_len: int = 77,
762
+ text_len_t5: int = 256,
763
+
764
+ # block type
765
+ basic_block_type: str = "basic",
766
+
767
+ time_position_encoding = False,
768
+ time_position_encoding_type: str = "2d_rope",
769
+ after_norm = False,
770
+ resize_inpaint_mask_directly: bool = False,
771
+ enable_clip_in_inpaint: bool = True,
772
+ enable_text_attention_mask: bool = True,
773
+ add_noise_in_inpaint_model: bool = False,
774
+ ):
775
+ super().__init__()
776
+ # 4. Define output layers
777
+ if learn_sigma:
778
+ self.out_channels = in_channels * 2 if out_channels is None else out_channels
779
+ else:
780
+ self.out_channels = in_channels if out_channels is None else out_channels
781
+ self.enable_inpaint = in_channels * 2 != self.out_channels if learn_sigma else in_channels != self.out_channels
782
+ self.num_heads = num_attention_heads
783
+ self.inner_dim = num_attention_heads * attention_head_dim
784
+ self.basic_block_type = basic_block_type
785
+ self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
786
+ self.text_embedder = PixArtAlphaTextProjection(
787
+ in_features=cross_attention_dim_t5,
788
+ hidden_size=cross_attention_dim_t5 * 4,
789
+ out_features=cross_attention_dim,
790
+ act_fn="silu_fp32",
791
+ )
792
+
793
+ self.text_embedding_padding = nn.Parameter(
794
+ torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
795
+ )
796
+
797
+ self.pos_embed = PatchEmbed(
798
+ height=sample_size,
799
+ width=sample_size,
800
+ in_channels=in_channels,
801
+ embed_dim=hidden_size,
802
+ patch_size=patch_size,
803
+ pos_embed_type=None,
804
+ )
805
+
806
+ self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
807
+ hidden_size,
808
+ pooled_projection_dim=pooled_projection_dim,
809
+ seq_len=text_len_t5,
810
+ cross_attention_dim=cross_attention_dim_t5,
811
+ )
812
+
813
+ # 3. Define transformers blocks
814
+ if self.basic_block_type == "hybrid_attention":
815
+ self.blocks = nn.ModuleList(
816
+ [
817
+ HunyuanDiTBlock(
818
+ dim=self.inner_dim,
819
+ num_attention_heads=self.config.num_attention_heads,
820
+ activation_fn=activation_fn,
821
+ ff_inner_dim=int(self.inner_dim * mlp_ratio),
822
+ cross_attention_dim=cross_attention_dim,
823
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
824
+ skip=layer > num_layers // 2,
825
+ after_norm=after_norm,
826
+ time_position_encoding=time_position_encoding,
827
+ is_local_attention=False if layer % 2 == 0 else True,
828
+ local_attention_frames=2,
829
+ enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
830
+ )
831
+ for layer in range(num_layers)
832
+ ]
833
+ )
834
+ elif self.basic_block_type == "kvcompression_basic":
835
+ self.blocks = nn.ModuleList(
836
+ [
837
+ HunyuanDiTBlock(
838
+ dim=self.inner_dim,
839
+ num_attention_heads=self.config.num_attention_heads,
840
+ activation_fn=activation_fn,
841
+ ff_inner_dim=int(self.inner_dim * mlp_ratio),
842
+ cross_attention_dim=cross_attention_dim,
843
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
844
+ skip=layer > num_layers // 2,
845
+ after_norm=after_norm,
846
+ time_position_encoding=time_position_encoding,
847
+ kvcompression=False if layer < num_layers // 2 else True,
848
+ enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
849
+ )
850
+ for layer in range(num_layers)
851
+ ]
852
+ )
853
+ else:
854
+ self.blocks = nn.ModuleList(
855
+ [
856
+ HunyuanDiTBlock(
857
+ dim=self.inner_dim,
858
+ num_attention_heads=self.config.num_attention_heads,
859
+ activation_fn=activation_fn,
860
+ ff_inner_dim=int(self.inner_dim * mlp_ratio),
861
+ cross_attention_dim=cross_attention_dim,
862
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
863
+ skip=layer > num_layers // 2,
864
+ after_norm=after_norm,
865
+ time_position_encoding=time_position_encoding,
866
+ enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
867
+ )
868
+ for layer in range(num_layers)
869
+ ]
870
+ )
871
+
872
+ self.n_query = n_query
873
+ if self.enable_inpaint and enable_clip_in_inpaint:
874
+ self.clip_padding = nn.Parameter(
875
+ torch.randn((self.n_query, cross_attention_dim)) * 0.02
876
+ )
877
+ self.clip_projection = Resampler(
878
+ int(math.sqrt(n_query)),
879
+ embed_dim=cross_attention_dim,
880
+ num_heads=self.config.num_attention_heads,
881
+ kv_dim=projection_dim,
882
+ norm_layer=nn.LayerNorm,
883
+ )
884
+ else:
885
+ self.clip_padding = None
886
+ self.clip_projection = None
887
+
888
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
889
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
890
+
891
+ self.gradient_checkpointing = False
892
+
893
+ def _set_gradient_checkpointing(self, module, value=False):
894
+ if hasattr(module, "gradient_checkpointing"):
895
+ module.gradient_checkpointing = value
896
+
897
+ def forward(
898
+ self,
899
+ hidden_states,
900
+ timestep,
901
+ encoder_hidden_states=None,
902
+ text_embedding_mask=None,
903
+ encoder_hidden_states_t5=None,
904
+ text_embedding_mask_t5=None,
905
+ image_meta_size=None,
906
+ style=None,
907
+ image_rotary_emb=None,
908
+ inpaint_latents=None,
909
+ control_latents: torch.Tensor = None,
910
+ clip_encoder_hidden_states: Optional[torch.Tensor]=None,
911
+ clip_attention_mask: Optional[torch.Tensor]=None,
912
+ return_dict=True,
913
+ ):
914
+ """
915
+ The [`HunyuanTransformer3DModel`] forward method.
916
+
917
+ Args:
918
+ hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
919
+ The input tensor.
920
+ timestep ( `torch.LongTensor`, *optional*):
921
+ Used to indicate denoising step.
922
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
923
+ Conditional embeddings for cross attention layer. This is the output of `BertModel`.
924
+ text_embedding_mask: torch.Tensor
925
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
926
+ of `BertModel`.
927
+ encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
928
+ Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
929
+ text_embedding_mask_t5: torch.Tensor
930
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
931
+ of T5 Text Encoder.
932
+ image_meta_size (torch.Tensor):
933
+ Conditional embedding indicating the image sizes.
934
+ style: torch.Tensor:
935
+ Conditional embedding indicating the style.
936
+ image_rotary_emb (`torch.Tensor`):
937
+ The image rotary embeddings to apply on query and key tensors during attention calculation.
938
+ return_dict: bool
939
+ Whether to return a dictionary.
940
+ """
941
+ if inpaint_latents is not None:
942
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
943
+ if control_latents is not None:
944
+ hidden_states = torch.concat([hidden_states, control_latents], 1)
945
+
946
+ # Patchify the input latents: (B, C, F, H, W) -> (B, F*H*W, D)
947
+ patch_size = self.pos_embed.patch_size
948
+ video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // patch_size, hidden_states.shape[-1] // patch_size
949
+ hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
950
+ hidden_states = self.pos_embed(hidden_states)
951
+ hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
952
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
953
+
954
+ temb = self.time_extra_emb(
955
+ timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
956
+ ) # [B, D]
957
+
958
+ # text projection
959
+ batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
960
+ encoder_hidden_states_t5 = self.text_embedder(
961
+ encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
962
+ )
963
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
964
+
965
+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
966
+ text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
967
+ text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
968
+
969
+ encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
970
+
971
+ if clip_encoder_hidden_states is not None:
972
+ batch_size = encoder_hidden_states.shape[0]
973
+
974
+ clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
975
+ clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
976
+
977
+ clip_attention_mask = clip_attention_mask.unsqueeze(2).bool()
978
+ clip_encoder_hidden_states = torch.where(clip_attention_mask, clip_encoder_hidden_states, self.clip_padding)
979
+
980
+ skips = []
981
+ for layer, block in enumerate(self.blocks):
982
+ if layer > self.config.num_layers // 2:
983
+ skip = skips.pop()
984
+ if self.training and self.gradient_checkpointing:
985
+
986
+ def create_custom_forward(module, return_dict=None):
987
+ def custom_forward(*inputs):
988
+ if return_dict is not None:
989
+ return module(*inputs, return_dict=return_dict)
990
+ else:
991
+ return module(*inputs)
992
+
993
+ return custom_forward
994
+
995
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
996
+ args = {
997
+ "kvcompression_basic": [video_length, height, width, clip_encoder_hidden_states],
998
+ "basic": [video_length, height, width, clip_encoder_hidden_states],
999
+ "hybrid_attention": [video_length, height, width, clip_encoder_hidden_states],
1000
+ }[self.basic_block_type]
1001
+ hidden_states = torch.utils.checkpoint.checkpoint(
1002
+ create_custom_forward(block),
1003
+ hidden_states,
1004
+ encoder_hidden_states,
1005
+ temb,
1006
+ image_rotary_emb,
1007
+ skip,
1008
+ *args,
1009
+ **ckpt_kwargs,
1010
+ )
1011
+ else:
1012
+ kwargs = {
1013
+ "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1014
+ "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1015
+ "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1016
+ }[self.basic_block_type]
1017
+ hidden_states = block(
1018
+ hidden_states,
1019
+ temb=temb,
1020
+ encoder_hidden_states=encoder_hidden_states,
1021
+ image_rotary_emb=image_rotary_emb,
1022
+ skip=skip,
1023
+ **kwargs
1024
+ ) # (N, L, D)
1025
+ else:
1026
+ if self.training and self.gradient_checkpointing:
1027
+
1028
+ def create_custom_forward(module, return_dict=None):
1029
+ def custom_forward(*inputs):
1030
+ if return_dict is not None:
1031
+ return module(*inputs, return_dict=return_dict)
1032
+ else:
1033
+ return module(*inputs)
1034
+
1035
+ return custom_forward
1036
+
1037
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1038
+ args = {
1039
+ "kvcompression_basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
1040
+ "basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
1041
+ "hybrid_attention": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
1042
+ }[self.basic_block_type]
1043
+ hidden_states = torch.utils.checkpoint.checkpoint(
1044
+ create_custom_forward(block),
1045
+ hidden_states,
1046
+ encoder_hidden_states,
1047
+ temb,
1048
+ image_rotary_emb,
1049
+ *args,
1050
+ **ckpt_kwargs,
1051
+ )
1052
+ else:
1053
+ kwargs = {
1054
+ "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1055
+ "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1056
+ "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
1057
+ }[self.basic_block_type]
1058
+ hidden_states = block(
1059
+ hidden_states,
1060
+ temb=temb,
1061
+ encoder_hidden_states=encoder_hidden_states,
1062
+ image_rotary_emb=image_rotary_emb,
1063
+ disable_image_rotary_emb_in_attn1=True if layer==0 else False,
1064
+ **kwargs
1065
+ ) # (N, L, D)
1066
+
1067
+ if layer < (self.config.num_layers // 2 - 1):
1068
+ skips.append(hidden_states)
1069
+
1070
+ # final layer
1071
+ hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
1072
+ hidden_states = self.proj_out(hidden_states)
1073
+ # (N, L, patch_size ** 2 * out_channels)
1074
+
1075
+ hidden_states = hidden_states.reshape(
1076
+ shape=(hidden_states.shape[0], video_length, height, width, patch_size, patch_size, self.out_channels)
1077
+ )
1078
+ hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
1079
+ output = hidden_states.reshape(
1080
+ shape=(hidden_states.shape[0], self.out_channels, video_length, height * patch_size, width * patch_size)
1081
+ )
1082
+
1083
+ if not return_dict:
1084
+ return (output,)
1085
+ return Transformer2DModelOutput(sample=output)
1086
+
1087
+ @classmethod
1088
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
1089
+ if subfolder is not None:
1090
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1091
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1092
+
1093
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1094
+ if not os.path.isfile(config_file):
1095
+ raise RuntimeError(f"{config_file} does not exist")
1096
+ with open(config_file, "r") as f:
1097
+ config = json.load(f)
1098
+
1099
+ from diffusers.utils import WEIGHTS_NAME
1100
+ model = cls.from_config(config, **transformer_additional_kwargs)
1101
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1102
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1103
+ if os.path.exists(model_file_safetensors):
1104
+ from safetensors.torch import load_file, safe_open
1105
+ state_dict = load_file(model_file_safetensors)
1106
+ else:
1107
+ if not os.path.isfile(model_file):
1108
+ raise RuntimeError(f"{model_file} does not exist")
1109
+ state_dict = torch.load(model_file, map_location="cpu")
1110
+
1111
+ if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
1112
+ new_shape = model.state_dict()['pos_embed.proj.weight'].size()
1113
+ if len(new_shape) == 5:
1114
+ state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
1115
+ state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
1116
+ else:
1117
+ if model.state_dict()['pos_embed.proj.weight'].size()[1] > state_dict['pos_embed.proj.weight'].size()[1]:
1118
+ model.state_dict()['pos_embed.proj.weight'][:, :state_dict['pos_embed.proj.weight'].size()[1], :, :] = state_dict['pos_embed.proj.weight']
1119
+ model.state_dict()['pos_embed.proj.weight'][:, state_dict['pos_embed.proj.weight'].size()[1]:, :, :] = 0
1120
+ state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
1121
+ else:
1122
+ model.state_dict()['pos_embed.proj.weight'][:, :, :, :] = state_dict['pos_embed.proj.weight'][:, :model.state_dict()['pos_embed.proj.weight'].size()[1], :, :]
1123
+ state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
1124
+
1125
+ if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
1126
+ if model.state_dict()['proj_out.weight'].size()[0] > state_dict['proj_out.weight'].size()[0]:
1127
+ model.state_dict()['proj_out.weight'][:state_dict['proj_out.weight'].size()[0], :] = state_dict['proj_out.weight']
1128
+ state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
1129
+ else:
1130
+ model.state_dict()['proj_out.weight'][:, :] = state_dict['proj_out.weight'][:model.state_dict()['proj_out.weight'].size()[0], :]
1131
+ state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
1132
+
1133
+ if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
1134
+ if model.state_dict()['proj_out.bias'].size()[0] > state_dict['proj_out.bias'].size()[0]:
1135
+ model.state_dict()['proj_out.bias'][:state_dict['proj_out.bias'].size()[0]] = state_dict['proj_out.bias']
1136
+ state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
1137
+ else:
1138
+ model.state_dict()['proj_out.bias'][:, :] = state_dict['proj_out.bias'][:model.state_dict()['proj_out.bias'].size()[0], :]
1139
+ state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
1140
+
1141
+ tmp_state_dict = {}
1142
+ for key in state_dict:
1143
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1144
+ tmp_state_dict[key] = state_dict[key]
1145
+ else:
1146
+ print(key, "Size don't match, skip")
1147
+ state_dict = tmp_state_dict
1148
+
1149
+ m, u = model.load_state_dict(state_dict, strict=False)
1150
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1151
+ print(m)
1152
+
1153
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
1154
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
1155
+
1156
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1157
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1158
+
1159
+ return model
1160
+
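The channel-adaptation logic in `from_pretrained_2d` above handles the case where the configured `in_channels` differs from the checkpoint (for example when inpaint or control latents are concatenated onto the input): extra input channels of `pos_embed.proj.weight` are zero-initialised and surplus checkpoint channels are truncated. A minimal standalone sketch of that idea, using a hypothetical `pad_conv_in_channels` helper that is not part of this commit:

import torch

def pad_conv_in_channels(old_weight: torch.Tensor, new_in_channels: int) -> torch.Tensor:
    # Grow a Conv2d kernel of shape (out, in, kh, kw) to more input channels; the new
    # channels start at zero so the padded kernel initially ignores the extra latents.
    out_c, in_c, kh, kw = old_weight.shape
    if new_in_channels <= in_c:
        return old_weight[:, :new_in_channels].clone()  # truncate when the model expects fewer channels
    padded = old_weight.new_zeros(out_c, new_in_channels, kh, kw)
    padded[:, :in_c] = old_weight  # keep the pretrained channels untouched
    return padded

# Example: a 4-channel patch-embedding kernel expanded to 12 input channels.
w2d = torch.randn(1152, 4, 2, 2)
w3d = pad_conv_in_channels(w2d, 12)
assert w3d.shape == (1152, 12, 2, 2) and torch.equal(w3d[:, :4], w2d)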
1161
+ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1162
+ _supports_gradient_checkpointing = True
1163
+
1164
+ @register_to_config
1165
+ def __init__(
1166
+ self,
1167
+ num_attention_heads: int = 30,
1168
+ attention_head_dim: int = 64,
1169
+ in_channels: Optional[int] = None,
1170
+ out_channels: Optional[int] = None,
1171
+ patch_size: Optional[int] = None,
1172
+ sample_width: int = 90,
1173
+ sample_height: int = 60,
1174
+ ref_channels: int = None,
1175
+ clip_channels: int = None,
1176
+
1177
+ activation_fn: str = "gelu-approximate",
1178
+ timestep_activation_fn: str = "silu",
1179
+ freq_shift: int = 0,
1180
+ num_layers: int = 30,
1181
+ dropout: float = 0.0,
1182
+ time_embed_dim: int = 512,
1183
+ text_embed_dim: int = 4096,
1184
+ text_embed_dim_t5: int = 4096,
1185
+ norm_eps: float = 1e-5,
1186
+
1187
+ norm_elementwise_affine: bool = True,
1188
+ flip_sin_to_cos: bool = True,
1189
+
1190
+ time_position_encoding_type: str = "3d_rope",
1191
+ after_norm = False,
1192
+ resize_inpaint_mask_directly: bool = False,
1193
+ enable_clip_in_inpaint: bool = True,
1194
+ enable_text_attention_mask: bool = True,
1195
+ add_noise_in_inpaint_model: bool = False,
1196
+ ):
1197
+ super().__init__()
1198
+ self.num_heads = num_attention_heads
1199
+ self.inner_dim = num_attention_heads * attention_head_dim
1200
+ self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
1201
+ self.patch_size = patch_size
1202
+
1203
+ post_patch_height = sample_height // patch_size
1204
+ post_patch_width = sample_width // patch_size
1205
+ self.post_patch_height = post_patch_height
1206
+ self.post_patch_width = post_patch_width
1207
+
1208
+ self.time_proj = Timesteps(self.inner_dim, flip_sin_to_cos, freq_shift)
1209
+ self.time_embedding = TimestepEmbedding(self.inner_dim, time_embed_dim, timestep_activation_fn)
1210
+
1211
+ self.proj = nn.Conv2d(
1212
+ in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
1213
+ )
1214
+ self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
1215
+ self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
1216
+
1217
+ if ref_channels is not None:
1218
+ self.ref_proj = nn.Conv2d(
1219
+ ref_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
1220
+ )
1221
+ ref_pos_embedding = get_2d_sincos_pos_embed(self.inner_dim, (post_patch_height, post_patch_width))
1222
+ ref_pos_embedding = torch.from_numpy(ref_pos_embedding)
1223
+ self.register_buffer("ref_pos_embedding", ref_pos_embedding, persistent=False)
1224
+
1225
+ if clip_channels is not None:
1226
+ self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
1227
+
1228
+ self.transformer_blocks = nn.ModuleList(
1229
+ [
1230
+ EasyAnimateDiTBlock(
1231
+ dim=self.inner_dim,
1232
+ num_attention_heads=num_attention_heads,
1233
+ attention_head_dim=attention_head_dim,
1234
+ time_embed_dim=time_embed_dim,
1235
+ dropout=dropout,
1236
+ activation_fn=activation_fn,
1237
+ norm_elementwise_affine=norm_elementwise_affine,
1238
+ norm_eps=norm_eps,
1239
+ after_norm=after_norm
1240
+ )
1241
+ for _ in range(num_layers)
1242
+ ]
1243
+ )
1244
+ self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
1245
+
1246
+ # 5. Output blocks
1247
+ self.norm_out = AdaLayerNorm(
1248
+ embedding_dim=time_embed_dim,
1249
+ output_dim=2 * self.inner_dim,
1250
+ norm_elementwise_affine=norm_elementwise_affine,
1251
+ norm_eps=norm_eps,
1252
+ chunk_dim=1,
1253
+ )
1254
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
1255
+
1256
+ self.gradient_checkpointing = False
1257
+
1258
+ def _set_gradient_checkpointing(self, module, value=False):
1259
+ self.gradient_checkpointing = value
1260
+
1261
+ def forward(
1262
+ self,
1263
+ hidden_states,
1264
+ timestep,
1265
+ timestep_cond = None,
1266
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1267
+ text_embedding_mask: Optional[torch.Tensor] = None,
1268
+ encoder_hidden_states_t5: Optional[torch.Tensor] = None,
1269
+ text_embedding_mask_t5: Optional[torch.Tensor] = None,
1270
+ image_meta_size = None,
1271
+ style = None,
1272
+ image_rotary_emb: Optional[torch.Tensor] = None,
1273
+ inpaint_latents: Optional[torch.Tensor] = None,
1274
+ control_latents: Optional[torch.Tensor] = None,
1275
+ ref_latents: Optional[torch.Tensor] = None,
1276
+ clip_encoder_hidden_states: Optional[torch.Tensor] = None,
1277
+ clip_attention_mask: Optional[torch.Tensor] = None,
1278
+ return_dict=True,
1279
+ ):
1280
+ batch_size, channels, video_length, height, width = hidden_states.size()
1281
+
1282
+ # 1. Time embedding
1283
+ temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
1284
+ temb = self.time_embedding(temb, timestep_cond)
1285
+
1286
+ # 2. Patch embedding
1287
+ if inpaint_latents is not None:
1288
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
1289
+ if control_latents is not None:
1290
+ hidden_states = torch.concat([hidden_states, control_latents], 1)
1291
+
1292
+ hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
1293
+ hidden_states = self.proj(hidden_states)
1294
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length, h=height // self.patch_size, w=width // self.patch_size)
1295
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
1296
+
1297
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
1298
+ if encoder_hidden_states_t5 is not None:
1299
+ encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
1300
+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
1301
+
1302
+ if ref_latents is not None:
1303
+ ref_batch, ref_channels, ref_video_length, ref_height, ref_width = ref_latents.shape
1304
+ ref_latents = rearrange(ref_latents, "b c f h w ->(b f) c h w")
1305
+ ref_latents = self.ref_proj(ref_latents)
1306
+ ref_latents = rearrange(ref_latents, "(b f) c h w -> b c f h w", f=ref_video_length, h=ref_height // self.patch_size, w=ref_width // self.patch_size)
1307
+ ref_latents = ref_latents.flatten(2).transpose(1, 2)
1308
+
1309
+ emb_size = hidden_states.size()[-1]
1310
+ ref_pos_embedding = self.ref_pos_embedding
1311
+ ref_pos_embedding_interpolate = ref_pos_embedding.view(1, 1, self.post_patch_height, self.post_patch_width, emb_size).permute([0, 4, 1, 2, 3])
1312
+ ref_pos_embedding_interpolate = F.interpolate(
1313
+ ref_pos_embedding_interpolate,
1314
+ size=[1, height // self.config.patch_size, width // self.config.patch_size],
1315
+ mode='trilinear', align_corners=False
1316
+ )
1317
+ ref_pos_embedding_interpolate = ref_pos_embedding_interpolate.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
1318
+ ref_latents = ref_latents + ref_pos_embedding_interpolate
1319
+
1320
+ encoder_hidden_states = ref_latents
1321
+
1322
+ if clip_encoder_hidden_states is not None:
1323
+ clip_encoder_hidden_states = self.clip_proj(clip_encoder_hidden_states)
1324
+
1325
+ encoder_hidden_states = torch.concat([clip_encoder_hidden_states, ref_latents], dim=1)
1326
+
1327
+ # 4. Transformer blocks
1328
+ for i, block in enumerate(self.transformer_blocks):
1329
+ if self.training and self.gradient_checkpointing:
1330
+ def create_custom_forward(module, return_dict=None):
1331
+ def custom_forward(*inputs):
1332
+ if return_dict is not None:
1333
+ return module(*inputs, return_dict=return_dict)
1334
+ else:
1335
+ return module(*inputs)
1336
+
1337
+ return custom_forward
1338
+
1339
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1340
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
1341
+ create_custom_forward(block),
1342
+ hidden_states,
1343
+ encoder_hidden_states,
1344
+ temb,
1345
+ image_rotary_emb,
1346
+ **ckpt_kwargs,
1347
+ )
1348
+ else:
1349
+ hidden_states, encoder_hidden_states = block(
1350
+ hidden_states=hidden_states,
1351
+ encoder_hidden_states=encoder_hidden_states,
1352
+ temb=temb,
1353
+ image_rotary_emb=image_rotary_emb,
1354
+ )
1355
+
1356
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1357
+ hidden_states = self.norm_final(hidden_states)
1358
+ hidden_states = hidden_states[:, encoder_hidden_states.size()[1]:]
1359
+
1360
+ # 5. Final block
1361
+ hidden_states = self.norm_out(hidden_states, temb=temb)
1362
+ hidden_states = self.proj_out(hidden_states)
1363
+
1364
+ # 6. Unpatchify
1365
+ p = self.config.patch_size
1366
+ output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p)
1367
+ output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
1368
+
1369
+ if not return_dict:
1370
+ return (output,)
1371
+ return Transformer2DModelOutput(sample=output)
1372
+
1373
+ @classmethod
1374
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
1375
+ if subfolder is not None:
1376
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1377
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1378
+
1379
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1380
+ if not os.path.isfile(config_file):
1381
+ raise RuntimeError(f"{config_file} does not exist")
1382
+ with open(config_file, "r") as f:
1383
+ config = json.load(f)
1384
+
1385
+ from diffusers.utils import WEIGHTS_NAME
1386
+ model = cls.from_config(config, **transformer_additional_kwargs)
1387
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1388
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1389
+ if os.path.exists(model_file):
1390
+ state_dict = torch.load(model_file, map_location="cpu")
1391
+ elif os.path.exists(model_file_safetensors):
1392
+ from safetensors.torch import load_file, safe_open
1393
+ state_dict = load_file(model_file_safetensors)
1394
+ else:
1395
+ from safetensors.torch import load_file, safe_open
1396
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1397
+ state_dict = {}
1398
+ for model_file_safetensors in model_files_safetensors:
1399
+ _state_dict = load_file(model_file_safetensors)
1400
+ for key in _state_dict:
1401
+ state_dict[key] = _state_dict[key]
1402
+
1403
+ if model.state_dict()['proj.weight'].size() != state_dict['proj.weight'].size():
1404
+ new_shape = model.state_dict()['proj.weight'].size()
1405
+ if len(new_shape) == 5:
1406
+ state_dict['proj.weight'] = state_dict['proj.weight'].unsqueeze(2).expand(new_shape).clone()
1407
+ state_dict['proj.weight'][:, :, :-1] = 0
1408
+ else:
1409
+ if model.state_dict()['proj.weight'].size()[1] > state_dict['proj.weight'].size()[1]:
1410
+ model.state_dict()['proj.weight'][:, :state_dict['proj.weight'].size()[1], :, :] = state_dict['proj.weight']
1411
+ model.state_dict()['proj.weight'][:, state_dict['proj.weight'].size()[1]:, :, :] = 0
1412
+ state_dict['proj.weight'] = model.state_dict()['proj.weight']
1413
+ else:
1414
+ model.state_dict()['proj.weight'][:, :, :, :] = state_dict['proj.weight'][:, :model.state_dict()['proj.weight'].size()[1], :, :]
1415
+ state_dict['proj.weight'] = model.state_dict()['proj.weight']
1416
+
1417
+ tmp_state_dict = {}
1418
+ for key in state_dict:
1419
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1420
+ tmp_state_dict[key] = state_dict[key]
1421
+ else:
1422
+ print(key, "Size don't match, skip")
1423
+
1424
+ state_dict = tmp_state_dict
1425
+
1426
+ m, u = model.load_state_dict(state_dict, strict=False)
1427
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1428
+ print(m)
1429
+
1430
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1431
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1432
+
1433
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1434
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1435
+
1436
  return model
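Both `from_pretrained_2d` loaders above resolve `config.json` plus either a single `.bin`/`.safetensors` file or a set of sharded safetensors, then load the weights with `strict=False` after filtering size mismatches. A minimal loading sketch, assuming a local diffusers-style checkpoint directory; the path and the extra kwargs are illustrative placeholders, not values taken from this commit:

import torch
from easyanimate.models.transformer3d import EasyAnimateTransformer3DModel

# Hypothetical local checkpoint directory containing config.json and the weight files.
checkpoint_dir = "path/to/easyanimate_checkpoint"

transformer = EasyAnimateTransformer3DModel.from_pretrained_2d(
    checkpoint_dir,
    subfolder="transformer",  # standard diffusers layout
    transformer_additional_kwargs={"after_norm": False},  # any __init__ override accepted by from_config
)
transformer = transformer.to("cuda", dtype=torch.bfloat16).eval()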
easyanimate/pipeline/pipeline_easyanimate.py CHANGED
@@ -12,9 +12,9 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import html
16
  import inspect
17
- import copy
18
  import re
19
  import urllib.parse as ul
20
  from dataclasses import dataclass
@@ -154,7 +154,8 @@ class EasyAnimatePipeline(DiffusionPipeline):
154
 
155
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
156
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
157
-
 
158
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
159
  def mask_text_embeddings(self, emb, mask):
160
  if emb.shape[0] == 1:
@@ -548,31 +549,13 @@ class EasyAnimatePipeline(DiffusionPipeline):
548
  prefix_index_before = mini_batch_encoder // 2
549
  prefix_index_after = mini_batch_encoder - prefix_index_before
550
  pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
551
-
552
- if self.vae.slice_compression_vae:
553
- latents = self.vae.encode(pixel_values)[0]
554
- latents = latents.sample()
555
- else:
556
- new_pixel_values = []
557
- for i in range(0, pixel_values.shape[2], mini_batch_encoder):
558
- with torch.no_grad():
559
- pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
560
- pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
561
- pixel_values_bs = pixel_values_bs.sample()
562
- new_pixel_values.append(pixel_values_bs)
563
- latents = torch.cat(new_pixel_values, dim = 2)
564
-
565
- if self.vae.slice_compression_vae:
566
- middle_video = self.vae.decode(latents)[0]
567
- else:
568
- middle_video = []
569
- for i in range(0, latents.shape[2], mini_batch_decoder):
570
- with torch.no_grad():
571
- start_index = i
572
- end_index = i + mini_batch_decoder
573
- latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
574
- middle_video.append(latents_bs)
575
- middle_video = torch.cat(middle_video, 2)
576
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
577
  return video
578
 
@@ -582,17 +565,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
582
  if self.vae.quant_conv.weight.ndim==5:
583
  mini_batch_encoder = self.vae.mini_batch_encoder
584
  mini_batch_decoder = self.vae.mini_batch_decoder
585
- if self.vae.slice_compression_vae:
586
- video = self.vae.decode(latents)[0]
587
- else:
588
- video = []
589
- for i in range(0, latents.shape[2], mini_batch_decoder):
590
- with torch.no_grad():
591
- start_index = i
592
- end_index = i + mini_batch_decoder
593
- latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
594
- video.append(latents_bs)
595
- video = torch.cat(video, 2)
596
  video = video.clamp(-1, 1)
597
  video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
598
  else:
@@ -607,6 +580,9 @@ class EasyAnimatePipeline(DiffusionPipeline):
607
  video = video.cpu().float().numpy()
608
  return video
609
 
 
 
 
610
  @torch.no_grad()
611
  @replace_example_docstring(EXAMPLE_DOC_STRING)
612
  def __call__(
@@ -633,6 +609,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
633
  callback_steps: int = 1,
634
  clean_caption: bool = True,
635
  max_sequence_length: int = 120,
 
636
  **kwargs,
637
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
638
  """
@@ -780,9 +757,16 @@ class EasyAnimatePipeline(DiffusionPipeline):
780
 
781
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
782
 
 
 
 
 
 
783
  # 7. Denoising loop
784
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
785
-
 
 
786
  with self.progress_bar(total=num_inference_steps) as progress_bar:
787
  for i, t in enumerate(timesteps):
788
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -834,6 +818,12 @@ class EasyAnimatePipeline(DiffusionPipeline):
834
  step_idx = i // getattr(self.scheduler, "order", 1)
835
  callback(step_idx, t, latents)
836
 
 
 
 
 
 
 
837
  # Post-processing
838
  video = self.decode_latents(latents)
839
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import copy
16
  import html
17
  import inspect
 
18
  import re
19
  import urllib.parse as ul
20
  from dataclasses import dataclass
 
154
 
155
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
156
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
157
+ self.enable_autocast_float8_transformer_flag = False
158
+
159
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
160
  def mask_text_embeddings(self, emb, mask):
161
  if emb.shape[0] == 1:
 
549
  prefix_index_before = mini_batch_encoder // 2
550
  prefix_index_after = mini_batch_encoder - prefix_index_before
551
  pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
552
+
553
+ # Encode middle videos
554
+ latents = self.vae.encode(pixel_values)[0]
555
+ latents = latents.mode()
556
+ # Decode middle videos
557
+ middle_video = self.vae.decode(latents)[0]
558
+
 
 
 
 
 
559
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
560
  return video
561
 
 
565
  if self.vae.quant_conv.weight.ndim==5:
566
  mini_batch_encoder = self.vae.mini_batch_encoder
567
  mini_batch_decoder = self.vae.mini_batch_decoder
568
+ video = self.vae.decode(latents)[0]
 
 
 
 
569
  video = video.clamp(-1, 1)
570
  video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
571
  else:
 
580
  video = video.cpu().float().numpy()
581
  return video
582
 
583
+ def enable_autocast_float8_transformer(self):
584
+ self.enable_autocast_float8_transformer_flag = True
585
+
586
  @torch.no_grad()
587
  @replace_example_docstring(EXAMPLE_DOC_STRING)
588
  def __call__(
 
609
  callback_steps: int = 1,
610
  clean_caption: bool = True,
611
  max_sequence_length: int = 120,
612
+ comfyui_progressbar: bool = False,
613
  **kwargs,
614
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
615
  """
 
757
 
758
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
759
 
760
+ torch.cuda.empty_cache()
761
+ if self.enable_autocast_float8_transformer_flag:
762
+ origin_weight_dtype = self.transformer.dtype
763
+ self.transformer = self.transformer.to(torch.float8_e4m3fn)
764
+
765
  # 7. Denoising loop
766
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
767
+ if comfyui_progressbar:
768
+ from comfy.utils import ProgressBar
769
+ pbar = ProgressBar(num_inference_steps)
770
  with self.progress_bar(total=num_inference_steps) as progress_bar:
771
  for i, t in enumerate(timesteps):
772
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
818
  step_idx = i // getattr(self.scheduler, "order", 1)
819
  callback(step_idx, t, latents)
820
 
821
+ if comfyui_progressbar:
822
+ pbar.update(1)
823
+
824
+ if self.enable_autocast_float8_transformer_flag:
825
+ self.transformer = self.transformer.to("cpu", origin_weight_dtype)
826
+
827
  # Post-processing
828
  video = self.decode_latents(latents)
829
 
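The `enable_autocast_float8_transformer` path added above stores a flag and, inside `__call__`, casts the transformer weights to `torch.float8_e4m3fn` for the denoising loop before restoring the original dtype afterwards. A standalone sketch of that weight round-trip, assuming a PyTorch build that exposes `torch.float8_e4m3fn`; the context-manager wrapper is illustrative only, and real inference still needs per-layer upcasting around the matmuls:

import contextlib

import torch
from torch import nn

@contextlib.contextmanager
def transformer_weights_in_float8(module: nn.Module):
    # Temporarily store the module's floating-point weights in float8_e4m3fn to save memory.
    origin_dtype = next(module.parameters()).dtype
    module.to(torch.float8_e4m3fn)
    try:
        yield module
    finally:
        module.to(origin_dtype)  # restore the original precision

# Illustrative usage with an already-built pipeline object:
# with transformer_weights_in_float8(pipeline.transformer):
#     video = pipeline(prompt="a cat playing piano", num_inference_steps=30)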
easyanimate/pipeline/pipeline_easyanimate_inpaint.py CHANGED
@@ -12,14 +12,13 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import html
16
  import inspect
17
  import re
18
- import gc
19
- import copy
20
  import urllib.parse as ul
21
  from dataclasses import dataclass
22
- from PIL import Image
23
  from typing import Callable, List, Optional, Tuple, Union
24
 
25
  import numpy as np
@@ -34,9 +33,10 @@ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
34
  replace_example_docstring)
35
  from diffusers.utils.torch_utils import randn_tensor
36
  from einops import rearrange
 
37
  from tqdm import tqdm
38
- from transformers import T5EncoderModel, T5Tokenizer
39
- from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
40
 
41
  from ..models.transformer3d import Transformer3DModel
42
 
@@ -129,6 +129,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
129
  self.mask_processor = VaeImageProcessor(
130
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
131
  )
 
132
 
133
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
134
  def mask_text_embeddings(self, emb, mask):
@@ -493,6 +494,60 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
493
 
494
  return caption.strip()
495
 
 
 
 
 
 
 
496
  def prepare_latents(
497
  self,
498
  batch_size,
@@ -529,22 +584,11 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
529
  bs = 1
530
  mini_batch_encoder = self.vae.mini_batch_encoder
531
  new_video = []
532
- if self.vae.slice_compression_vae:
533
- for i in range(0, video.shape[0], bs):
534
- video_bs = video[i : i + bs]
535
- video_bs = self.vae.encode(video_bs)[0]
536
- video_bs = video_bs.sample()
537
- new_video.append(video_bs)
538
- else:
539
- for i in range(0, video.shape[0], bs):
540
- new_video_mini_batch = []
541
- for j in range(0, video.shape[2], mini_batch_encoder):
542
- video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :]
543
- video_bs = self.vae.encode(video_bs)[0]
544
- video_bs = video_bs.sample()
545
- new_video_mini_batch.append(video_bs)
546
- new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2)
547
- new_video.append(new_video_mini_batch)
548
  video = torch.cat(new_video, dim = 0)
549
  video = video * self.vae.config.scaling_factor
550
 
@@ -585,31 +629,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
585
  prefix_index_before = mini_batch_encoder // 2
586
  prefix_index_after = mini_batch_encoder - prefix_index_before
587
  pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
588
-
589
- if self.vae.slice_compression_vae:
590
- latents = self.vae.encode(pixel_values)[0]
591
- latents = latents.sample()
592
- else:
593
- new_pixel_values = []
594
- for i in range(0, pixel_values.shape[2], mini_batch_encoder):
595
- with torch.no_grad():
596
- pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
597
- pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
598
- pixel_values_bs = pixel_values_bs.sample()
599
- new_pixel_values.append(pixel_values_bs)
600
- latents = torch.cat(new_pixel_values, dim = 2)
601
-
602
- if self.vae.slice_compression_vae:
603
- middle_video = self.vae.decode(latents)[0]
604
- else:
605
- middle_video = []
606
- for i in range(0, latents.shape[2], mini_batch_decoder):
607
- with torch.no_grad():
608
- start_index = i
609
- end_index = i + mini_batch_decoder
610
- latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
611
- middle_video.append(latents_bs)
612
- middle_video = torch.cat(middle_video, 2)
613
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
614
  return video
615
 
@@ -619,17 +645,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
619
  if self.vae.quant_conv.weight.ndim==5:
620
  mini_batch_encoder = self.vae.mini_batch_encoder
621
  mini_batch_decoder = self.vae.mini_batch_decoder
622
- if self.vae.slice_compression_vae:
623
- video = self.vae.decode(latents)[0]
624
- else:
625
- video = []
626
- for i in range(0, latents.shape[2], mini_batch_decoder):
627
- with torch.no_grad():
628
- start_index = i
629
- end_index = i + mini_batch_decoder
630
- latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
631
- video.append(latents_bs)
632
- video = torch.cat(video, 2)
633
  video = video.clamp(-1, 1)
634
  video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
635
  else:
@@ -668,84 +684,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
668
 
669
  return timesteps, num_inference_steps - t_start
670
 
671
- def prepare_mask_latents(
672
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
673
- ):
674
- # resize the mask to latents shape as we concatenate the mask to the latents
675
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
676
- # and half precision
677
- video_length = mask.shape[2]
678
-
679
- mask = mask.to(device=device, dtype=self.vae.dtype)
680
- if self.vae.quant_conv.weight.ndim==5:
681
- bs = 1
682
- mini_batch_encoder = self.vae.mini_batch_encoder
683
- new_mask = []
684
- if self.vae.slice_compression_vae:
685
- for i in range(0, mask.shape[0], bs):
686
- mask_bs = mask[i : i + bs]
687
- mask_bs = self.vae.encode(mask_bs)[0]
688
- mask_bs = mask_bs.sample()
689
- new_mask.append(mask_bs)
690
- else:
691
- for i in range(0, mask.shape[0], bs):
692
- new_mask_mini_batch = []
693
- for j in range(0, mask.shape[2], mini_batch_encoder):
694
- mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :]
695
- mask_bs = self.vae.encode(mask_bs)[0]
696
- mask_bs = mask_bs.sample()
697
- new_mask_mini_batch.append(mask_bs)
698
- new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
699
- new_mask.append(new_mask_mini_batch)
700
- mask = torch.cat(new_mask, dim = 0)
701
- mask = mask * self.vae.config.scaling_factor
702
-
703
- else:
704
- if mask.shape[1] == 4:
705
- mask = mask
706
- else:
707
- video_length = mask.shape[2]
708
- mask = rearrange(mask, "b c f h w -> (b f) c h w")
709
- mask = self._encode_vae_image(mask, generator=generator)
710
- mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
711
-
712
- masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
713
- if self.vae.quant_conv.weight.ndim==5:
714
- bs = 1
715
- mini_batch_encoder = self.vae.mini_batch_encoder
716
- new_mask_pixel_values = []
717
- if self.vae.slice_compression_vae:
718
- for i in range(0, masked_image.shape[0], bs):
719
- mask_pixel_values_bs = masked_image[i : i + bs]
720
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
721
- mask_pixel_values_bs = mask_pixel_values_bs.sample()
722
- new_mask_pixel_values.append(mask_pixel_values_bs)
723
- else:
724
- for i in range(0, masked_image.shape[0], bs):
725
- new_mask_pixel_values_mini_batch = []
726
- for j in range(0, masked_image.shape[2], mini_batch_encoder):
727
- mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :]
728
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
729
- mask_pixel_values_bs = mask_pixel_values_bs.sample()
730
- new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
731
- new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
732
- new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
733
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
734
- masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
735
-
736
- else:
737
- if masked_image.shape[1] == 4:
738
- masked_image_latents = masked_image
739
- else:
740
- video_length = mask.shape[2]
741
- masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
742
- masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
743
- masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
744
 
745
- # aligning device to prevent device errors when concating it with the latent model input
746
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
747
- return mask, masked_image_latents
748
-
749
  @torch.no_grad()
750
  @replace_example_docstring(EXAMPLE_DOC_STRING)
751
  def __call__(
@@ -779,6 +720,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
779
  max_sequence_length: int = 120,
780
  clip_image: Image = None,
781
  clip_apply_ratio: float = 0.50,
 
 
782
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
783
  """
784
  Function invoked when calling the pipeline for generation.
@@ -1057,10 +1000,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
1057
  gc.collect()
1058
  torch.cuda.empty_cache()
1059
  torch.cuda.ipc_collect()
 
 
 
1060
 
1061
  # 10. Denoising loop
1062
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1063
  self._num_timesteps = len(timesteps)
 
 
 
1064
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1065
  for i, t in enumerate(timesteps):
1066
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -1130,16 +1079,19 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
1130
  step_idx = i // getattr(self.scheduler, "order", 1)
1131
  callback(step_idx, t, latents)
1132
 
 
 
 
 
 
 
1133
  gc.collect()
1134
  torch.cuda.empty_cache()
1135
  torch.cuda.ipc_collect()
1136
 
1137
  # Post-processing
1138
  video = self.decode_latents(latents)
1139
-
1140
- gc.collect()
1141
- torch.cuda.empty_cache()
1142
- torch.cuda.ipc_collect()
1143
  # Convert to tensor
1144
  if output_type == "latent":
1145
  video = torch.from_numpy(video)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import copy
16
+ import gc
17
  import html
18
  import inspect
19
  import re
 
 
20
  import urllib.parse as ul
21
  from dataclasses import dataclass
 
22
  from typing import Callable, List, Optional, Tuple, Union
23
 
24
  import numpy as np
 
33
  replace_example_docstring)
34
  from diffusers.utils.torch_utils import randn_tensor
35
  from einops import rearrange
36
+ from PIL import Image
37
  from tqdm import tqdm
38
+ from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
39
+ T5EncoderModel, T5Tokenizer)
40
 
41
  from ..models.transformer3d import Transformer3DModel
42
 
 
129
  self.mask_processor = VaeImageProcessor(
130
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
131
  )
132
+ self.enable_autocast_float8_transformer_flag = False
133
 
134
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
135
  def mask_text_embeddings(self, emb, mask):
 
494
 
495
  return caption.strip()
496
 
497
+ def prepare_mask_latents(
498
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
499
+ ):
500
+ # resize the mask to latents shape as we concatenate the mask to the latents
501
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
502
+ # and half precision
503
+ video_length = mask.shape[2]
504
+
505
+ mask = mask.to(device=device, dtype=self.vae.dtype)
506
+ if self.vae.quant_conv.weight.ndim==5:
507
+ bs = 1
508
+ new_mask = []
509
+ for i in range(0, mask.shape[0], bs):
510
+ mask_bs = mask[i : i + bs]
511
+ mask_bs = self.vae.encode(mask_bs)[0]
512
+ mask_bs = mask_bs.sample()
513
+ new_mask.append(mask_bs)
514
+ mask = torch.cat(new_mask, dim = 0)
515
+ mask = mask * self.vae.config.scaling_factor
516
+
517
+ else:
518
+ if mask.shape[1] == 4:
519
+ mask = mask
520
+ else:
521
+ video_length = mask.shape[2]
522
+ mask = rearrange(mask, "b c f h w -> (b f) c h w")
523
+ mask = self._encode_vae_image(mask, generator=generator)
524
+ mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
525
+
526
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
527
+ if self.vae.quant_conv.weight.ndim==5:
528
+ bs = 1
529
+ new_mask_pixel_values = []
530
+ for i in range(0, masked_image.shape[0], bs):
531
+ mask_pixel_values_bs = masked_image[i : i + bs]
532
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
533
+ mask_pixel_values_bs = mask_pixel_values_bs.sample()
534
+ new_mask_pixel_values.append(mask_pixel_values_bs)
535
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
536
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
537
+
538
+ else:
539
+ if masked_image.shape[1] == 4:
540
+ masked_image_latents = masked_image
541
+ else:
542
+ video_length = mask.shape[2]
543
+ masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
544
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
545
+ masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
546
+
547
+ # align device to prevent device errors when concatenating it with the latent model input
548
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
549
+ return mask, masked_image_latents
550
+
551
  def prepare_latents(
552
  self,
553
  batch_size,
 
584
  bs = 1
585
  mini_batch_encoder = self.vae.mini_batch_encoder
586
  new_video = []
587
+ for i in range(0, video.shape[0], bs):
588
+ video_bs = video[i : i + bs]
589
+ video_bs = self.vae.encode(video_bs)[0]
590
+ video_bs = video_bs.sample()
591
+ new_video.append(video_bs)
592
  video = torch.cat(new_video, dim = 0)
593
  video = video * self.vae.config.scaling_factor
594
 
 
629
  prefix_index_before = mini_batch_encoder // 2
630
  prefix_index_after = mini_batch_encoder - prefix_index_before
631
  pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
632
+
633
+ # Encode middle videos
634
+ latents = self.vae.encode(pixel_values)[0]
635
+ latents = latents.sample()
636
+ # Decode middle videos
637
+ middle_video = self.vae.decode(latents)[0]
638
+
639
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
640
  return video
641
 
 
645
  if self.vae.quant_conv.weight.ndim==5:
646
  mini_batch_encoder = self.vae.mini_batch_encoder
647
  mini_batch_decoder = self.vae.mini_batch_decoder
648
+ video = self.vae.decode(latents)[0]
649
  video = video.clamp(-1, 1)
650
  video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
651
  else:
 
684
 
685
  return timesteps, num_inference_steps - t_start
686
 
687
+ def enable_autocast_float8_transformer(self):
688
+ self.enable_autocast_float8_transformer_flag = True
689
 
690
  @torch.no_grad()
691
  @replace_example_docstring(EXAMPLE_DOC_STRING)
692
  def __call__(
 
720
  max_sequence_length: int = 120,
721
  clip_image: Image = None,
722
  clip_apply_ratio: float = 0.50,
723
+ comfyui_progressbar: bool = False,
724
+ **kwargs,
725
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
726
  """
727
  Function invoked when calling the pipeline for generation.
 
1000
  gc.collect()
1001
  torch.cuda.empty_cache()
1002
  torch.cuda.ipc_collect()
1003
+ if self.enable_autocast_float8_transformer_flag:
1004
+ origin_weight_dtype = self.transformer.dtype
1005
+ self.transformer = self.transformer.to(torch.float8_e4m3fn)
1006
 
1007
  # 10. Denoising loop
1008
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1009
  self._num_timesteps = len(timesteps)
1010
+ if comfyui_progressbar:
1011
+ from comfy.utils import ProgressBar
1012
+ pbar = ProgressBar(num_inference_steps)
1013
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1014
  for i, t in enumerate(timesteps):
1015
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
1079
  step_idx = i // getattr(self.scheduler, "order", 1)
1080
  callback(step_idx, t, latents)
1081
 
1082
+ if comfyui_progressbar:
1083
+ pbar.update(1)
1084
+
1085
+ if self.enable_autocast_float8_transformer_flag:
1086
+ self.transformer = self.transformer.to("cpu", origin_weight_dtype)
1087
+
1088
  gc.collect()
1089
  torch.cuda.empty_cache()
1090
  torch.cuda.ipc_collect()
1091
 
1092
  # Post-processing
1093
  video = self.decode_latents(latents)
1094
+
1095
  # Convert to tensor
1096
  if output_type == "latent":
1097
  video = torch.from_numpy(video)
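The inpaint-pipeline changes above add two opt-in conveniences: `enable_autocast_float8_transformer()`, which keeps the transformer weights in `torch.float8_e4m3fn` during the denoising loop and restores the original dtype afterwards, and a `comfyui_progressbar` flag that reports progress through ComfyUI. The sketch below shows how these might be used; the import path, checkpoint location, and dtype are illustrative assumptions rather than values taken from this commit.

```py
# Illustrative sketch only: checkpoint path and dtype are placeholder assumptions.
import torch
from easyanimate.pipeline.pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline

pipe = EasyAnimateInpaintPipeline.from_pretrained(
    "path/to/EasyAnimate-checkpoint", torch_dtype=torch.bfloat16
).to("cuda")

# Store transformer weights as float8_e4m3fn while sampling to reduce VRAM usage;
# the pipeline casts them back to the original dtype after the denoising loop.
pipe.enable_autocast_float8_transformer()

# Pass comfyui_progressbar=True in the __call__ kwargs to also report progress
# through ComfyUI's ProgressBar (requires the `comfy` package to be importable).
```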
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py ADDED
@@ -0,0 +1,925 @@
1
+ # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
21
+ from diffusers.image_processor import VaeImageProcessor
22
+ from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
23
+ get_3d_rotary_pos_embed)
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
27
+ StableDiffusionSafetyChecker
28
+ from diffusers.schedulers import DDIMScheduler
29
+ from diffusers.utils import (is_torch_xla_available, logging,
30
+ replace_example_docstring)
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from einops import rearrange
33
+ from tqdm import tqdm
34
+ from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
35
+ T5Tokenizer, T5EncoderModel)
36
+
37
+ from .pipeline_easyanimate import EasyAnimatePipelineOutput
38
+ from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
39
+
40
+ if is_torch_xla_available():
41
+ import torch_xla.core.xla_model as xm
42
+
43
+ XLA_AVAILABLE = True
44
+ else:
45
+ XLA_AVAILABLE = False
46
+
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> pass
54
+ ```
55
+ """
56
+
57
+
58
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
59
+ tw = tgt_width
60
+ th = tgt_height
61
+ h, w = src
62
+ r = h / w
63
+ if r > (th / tw):
64
+ resize_height = th
65
+ resize_width = int(round(th / h * w))
66
+ else:
67
+ resize_width = tw
68
+ resize_height = int(round(tw / w * h))
69
+
70
+ crop_top = int(round((th - resize_height) / 2.0))
71
+ crop_left = int(round((tw - resize_width) / 2.0))
72
+
73
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
74
+
75
+
76
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
77
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
78
+ """
79
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
80
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
81
+ """
82
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
83
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
84
+ # rescale the results from guidance (fixes overexposure)
85
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
86
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
87
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
88
+ return noise_cfg
89
+
90
+
91
+ class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline):
92
+ r"""
93
+ Pipeline for text-to-video generation using EasyAnimate.
94
+
95
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
96
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
97
+
98
+ EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP text
99
+ encoder (fine-tuned by the HunyuanDiT team).
100
+
101
+ Args:
102
+ vae ([`AutoencoderKLMagvit`]):
103
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
104
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
105
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
106
+ EasyAnimate uses a fine-tuned [bilingual CLIP].
107
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
108
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
109
+ transformer ([`EasyAnimateTransformer3DModel`]):
110
+ The EasyAnimate 3D transformer used to denoise the encoded video latents.
111
+ text_encoder_2 (`T5EncoderModel`):
112
+ The mT5 embedder.
113
+ tokenizer_2 (`T5Tokenizer`):
114
+ The tokenizer for the mT5 embedder.
115
+ scheduler ([`DDIMScheduler`]):
116
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
117
+ """
118
+
119
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
120
+ _optional_components = [
121
+ "safety_checker",
122
+ "feature_extractor",
123
+ "text_encoder_2",
124
+ "tokenizer_2",
125
+ "text_encoder",
126
+ "tokenizer",
127
+ ]
128
+ _exclude_from_cpu_offload = ["safety_checker"]
129
+ _callback_tensor_inputs = [
130
+ "latents",
131
+ "prompt_embeds",
132
+ "negative_prompt_embeds",
133
+ "prompt_embeds_2",
134
+ "negative_prompt_embeds_2",
135
+ ]
136
+
137
+ def __init__(
138
+ self,
139
+ vae: AutoencoderKLMagvit,
140
+ text_encoder: BertModel,
141
+ tokenizer: BertTokenizer,
142
+ text_encoder_2: T5EncoderModel,
143
+ tokenizer_2: T5Tokenizer,
144
+ transformer: EasyAnimateTransformer3DModel,
145
+ scheduler: DDIMScheduler,
146
+ safety_checker: StableDiffusionSafetyChecker,
147
+ feature_extractor: CLIPImageProcessor,
148
+ requires_safety_checker: bool = True,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.register_modules(
153
+ vae=vae,
154
+ text_encoder=text_encoder,
155
+ tokenizer=tokenizer,
156
+ tokenizer_2=tokenizer_2,
157
+ transformer=transformer,
158
+ scheduler=scheduler,
159
+ safety_checker=safety_checker,
160
+ feature_extractor=feature_extractor,
161
+ text_encoder_2=text_encoder_2,
162
+ )
163
+
164
+ if safety_checker is None and requires_safety_checker:
165
+ logger.warning(
166
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
167
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
168
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
169
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
170
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
171
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
172
+ )
173
+
174
+ if safety_checker is not None and feature_extractor is None:
175
+ raise ValueError(
176
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
177
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
178
+ )
179
+
180
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
181
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
182
+ self.enable_autocast_float8_transformer_flag = False
183
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
184
+
185
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
186
+ super().enable_sequential_cpu_offload(*args, **kwargs)
187
+ if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
188
+ import accelerate
189
+ accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
190
+ self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
191
+
192
+ def encode_prompt(
193
+ self,
194
+ prompt: str,
195
+ device: torch.device,
196
+ dtype: torch.dtype,
197
+ num_images_per_prompt: int = 1,
198
+ do_classifier_free_guidance: bool = True,
199
+ negative_prompt: Optional[str] = None,
200
+ prompt_embeds: Optional[torch.Tensor] = None,
201
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
202
+ prompt_attention_mask: Optional[torch.Tensor] = None,
203
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
204
+ max_sequence_length: Optional[int] = None,
205
+ text_encoder_index: int = 0,
206
+ actual_max_sequence_length: int = 256
207
+ ):
208
+ r"""
209
+ Encodes the prompt into text encoder hidden states.
210
+
211
+ Args:
212
+ prompt (`str` or `List[str]`, *optional*):
213
+ prompt to be encoded
214
+ device: (`torch.device`):
215
+ torch device
216
+ dtype (`torch.dtype`):
217
+ torch dtype
218
+ num_images_per_prompt (`int`):
219
+ number of images that should be generated per prompt
220
+ do_classifier_free_guidance (`bool`):
221
+ whether to use classifier free guidance or not
222
+ negative_prompt (`str` or `List[str]`, *optional*):
223
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
224
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
225
+ less than `1`).
226
+ prompt_embeds (`torch.Tensor`, *optional*):
227
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
228
+ provided, text embeddings will be generated from `prompt` input argument.
229
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
230
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
231
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
232
+ argument.
233
+ prompt_attention_mask (`torch.Tensor`, *optional*):
234
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
235
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
236
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
237
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
238
+ text_encoder_index (`int`, *optional*):
239
+ Index of the text encoder to use. `0` for clip and `1` for T5.
240
+ """
241
+ tokenizers = [self.tokenizer, self.tokenizer_2]
242
+ text_encoders = [self.text_encoder, self.text_encoder_2]
243
+
244
+ tokenizer = tokenizers[text_encoder_index]
245
+ text_encoder = text_encoders[text_encoder_index]
246
+
247
+ if max_sequence_length is None:
248
+ if text_encoder_index == 0:
249
+ max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
250
+ if text_encoder_index == 1:
251
+ max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
252
+ else:
253
+ max_length = max_sequence_length
254
+
255
+ if prompt is not None and isinstance(prompt, str):
256
+ batch_size = 1
257
+ elif prompt is not None and isinstance(prompt, list):
258
+ batch_size = len(prompt)
259
+ else:
260
+ batch_size = prompt_embeds.shape[0]
261
+
262
+ if prompt_embeds is None:
263
+ text_inputs = tokenizer(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=max_length,
267
+ truncation=True,
268
+ return_attention_mask=True,
269
+ return_tensors="pt",
270
+ )
271
+ text_input_ids = text_inputs.input_ids
272
+ if text_input_ids.shape[-1] > actual_max_sequence_length:
273
+ reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
274
+ text_inputs = tokenizer(
275
+ reprompt,
276
+ padding="max_length",
277
+ max_length=max_length,
278
+ truncation=True,
279
+ return_attention_mask=True,
280
+ return_tensors="pt",
281
+ )
282
+ text_input_ids = text_inputs.input_ids
283
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
284
+
285
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
286
+ text_input_ids, untruncated_ids
287
+ ):
288
+ _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
289
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
290
+ logger.warning(
291
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
292
+ f" {_actual_max_sequence_length} tokens: {removed_text}"
293
+ )
294
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
295
+
296
+ if self.transformer.config.enable_text_attention_mask:
297
+ prompt_embeds = text_encoder(
298
+ text_input_ids.to(device),
299
+ attention_mask=prompt_attention_mask,
300
+ )
301
+ else:
302
+ prompt_embeds = text_encoder(
303
+ text_input_ids.to(device)
304
+ )
305
+ prompt_embeds = prompt_embeds[0]
306
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
307
+
308
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
309
+
310
+ bs_embed, seq_len, _ = prompt_embeds.shape
311
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
312
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
313
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
314
+
315
+ # get unconditional embeddings for classifier free guidance
316
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
317
+ uncond_tokens: List[str]
318
+ if negative_prompt is None:
319
+ uncond_tokens = [""] * batch_size
320
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
321
+ raise TypeError(
322
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
323
+ f" {type(prompt)}."
324
+ )
325
+ elif isinstance(negative_prompt, str):
326
+ uncond_tokens = [negative_prompt]
327
+ elif batch_size != len(negative_prompt):
328
+ raise ValueError(
329
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
330
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
331
+ " the batch size of `prompt`."
332
+ )
333
+ else:
334
+ uncond_tokens = negative_prompt
335
+
336
+ max_length = prompt_embeds.shape[1]
337
+ uncond_input = tokenizer(
338
+ uncond_tokens,
339
+ padding="max_length",
340
+ max_length=max_length,
341
+ truncation=True,
342
+ return_tensors="pt",
343
+ )
344
+ uncond_input_ids = uncond_input.input_ids
345
+ if uncond_input_ids.shape[-1] > actual_max_sequence_length:
346
+ reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
347
+ uncond_input = tokenizer(
348
+ reuncond_tokens,
349
+ padding="max_length",
350
+ max_length=max_length,
351
+ truncation=True,
352
+ return_attention_mask=True,
353
+ return_tensors="pt",
354
+ )
355
+ uncond_input_ids = uncond_input.input_ids
356
+
357
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
358
+ if self.transformer.config.enable_text_attention_mask:
359
+ negative_prompt_embeds = text_encoder(
360
+ uncond_input.input_ids.to(device),
361
+ attention_mask=negative_prompt_attention_mask,
362
+ )
363
+ else:
364
+ negative_prompt_embeds = text_encoder(
365
+ uncond_input.input_ids.to(device)
366
+ )
367
+ negative_prompt_embeds = negative_prompt_embeds[0]
368
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
369
+
370
+ if do_classifier_free_guidance:
371
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
372
+ seq_len = negative_prompt_embeds.shape[1]
373
+
374
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
375
+
376
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
377
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
378
+
379
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
380
+
381
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
382
+ def run_safety_checker(self, image, device, dtype):
383
+ if self.safety_checker is None:
384
+ has_nsfw_concept = None
385
+ else:
386
+ if torch.is_tensor(image):
387
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
388
+ else:
389
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
390
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
391
+ image, has_nsfw_concept = self.safety_checker(
392
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
393
+ )
394
+ return image, has_nsfw_concept
395
+
396
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
397
+ def prepare_extra_step_kwargs(self, generator, eta):
398
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
399
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
400
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
401
+ # and should be between [0, 1]
402
+
403
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
404
+ extra_step_kwargs = {}
405
+ if accepts_eta:
406
+ extra_step_kwargs["eta"] = eta
407
+
408
+ # check if the scheduler accepts generator
409
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
410
+ if accepts_generator:
411
+ extra_step_kwargs["generator"] = generator
412
+ return extra_step_kwargs
413
+
414
+ def check_inputs(
415
+ self,
416
+ prompt,
417
+ height,
418
+ width,
419
+ negative_prompt=None,
420
+ prompt_embeds=None,
421
+ negative_prompt_embeds=None,
422
+ prompt_attention_mask=None,
423
+ negative_prompt_attention_mask=None,
424
+ prompt_embeds_2=None,
425
+ negative_prompt_embeds_2=None,
426
+ prompt_attention_mask_2=None,
427
+ negative_prompt_attention_mask_2=None,
428
+ callback_on_step_end_tensor_inputs=None,
429
+ ):
430
+ if height % 8 != 0 or width % 8 != 0:
431
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
432
+
433
+ if callback_on_step_end_tensor_inputs is not None and not all(
434
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
435
+ ):
436
+ raise ValueError(
437
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
438
+ )
439
+
440
+ if prompt is not None and prompt_embeds is not None:
441
+ raise ValueError(
442
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
443
+ " only forward one of the two."
444
+ )
445
+ elif prompt is None and prompt_embeds is None:
446
+ raise ValueError(
447
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
448
+ )
449
+ elif prompt is None and prompt_embeds_2 is None:
450
+ raise ValueError(
451
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
452
+ )
453
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
454
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
455
+
456
+ if prompt_embeds is not None and prompt_attention_mask is None:
457
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
458
+
459
+ if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
460
+ raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
461
+
462
+ if negative_prompt is not None and negative_prompt_embeds is not None:
463
+ raise ValueError(
464
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
465
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
466
+ )
467
+
468
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
469
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
470
+
471
+ if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
472
+ raise ValueError(
473
+ "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
474
+ )
475
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
476
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
477
+ raise ValueError(
478
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
479
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
480
+ f" {negative_prompt_embeds.shape}."
481
+ )
482
+ if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
483
+ if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
484
+ raise ValueError(
485
+ "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
486
+ f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
487
+ f" {negative_prompt_embeds_2.shape}."
488
+ )
489
+
490
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
491
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
492
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
493
+ if self.vae.cache_mag_vae:
494
+ mini_batch_encoder = self.vae.mini_batch_encoder
495
+ mini_batch_decoder = self.vae.mini_batch_decoder
496
+ shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
497
+ else:
498
+ mini_batch_encoder = self.vae.mini_batch_encoder
499
+ mini_batch_decoder = self.vae.mini_batch_decoder
500
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
501
+ else:
502
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
503
+
504
+ if isinstance(generator, list) and len(generator) != batch_size:
505
+ raise ValueError(
506
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
507
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
508
+ )
509
+
510
+ if latents is None:
511
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
512
+ else:
513
+ latents = latents.to(device)
514
+
515
+ # scale the initial noise by the standard deviation required by the scheduler
516
+ latents = latents * self.scheduler.init_noise_sigma
517
+ return latents
518
+
519
+ def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
520
+ if video.size()[2] <= mini_batch_encoder:
521
+ return video
522
+ prefix_index_before = mini_batch_encoder // 2
523
+ prefix_index_after = mini_batch_encoder - prefix_index_before
524
+ pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
525
+
526
+ # Encode middle videos
527
+ latents = self.vae.encode(pixel_values)[0]
528
+ latents = latents.mode()
529
+ # Decode middle videos
530
+ middle_video = self.vae.decode(latents)[0]
531
+
532
+ video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
533
+ return video
534
+
535
+ def decode_latents(self, latents):
536
+ video_length = latents.shape[2]
537
+ latents = 1 / self.vae.config.scaling_factor * latents
538
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
539
+ mini_batch_encoder = self.vae.mini_batch_encoder
540
+ mini_batch_decoder = self.vae.mini_batch_decoder
541
+ video = self.vae.decode(latents)[0]
542
+ video = video.clamp(-1, 1)
543
+ if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
544
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
545
+ else:
546
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
547
+ video = []
548
+ for frame_idx in tqdm(range(latents.shape[0])):
549
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
550
+ video = torch.cat(video)
551
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
552
+ video = (video / 2 + 0.5).clamp(0, 1)
553
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
554
+ video = video.cpu().float().numpy()
555
+ return video
556
+
557
+ @property
558
+ def guidance_scale(self):
559
+ return self._guidance_scale
560
+
561
+ @property
562
+ def guidance_rescale(self):
563
+ return self._guidance_rescale
564
+
565
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
566
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
567
+ # corresponds to doing no classifier free guidance.
568
+ @property
569
+ def do_classifier_free_guidance(self):
570
+ return self._guidance_scale > 1
571
+
572
+ @property
573
+ def num_timesteps(self):
574
+ return self._num_timesteps
575
+
576
+ @property
577
+ def interrupt(self):
578
+ return self._interrupt
579
+
580
+ def enable_autocast_float8_transformer(self):
581
+ self.enable_autocast_float8_transformer_flag = True
582
+
583
+ @torch.no_grad()
584
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
585
+ def __call__(
586
+ self,
587
+ prompt: Union[str, List[str]] = None,
588
+ video_length: Optional[int] = None,
589
+ height: Optional[int] = None,
590
+ width: Optional[int] = None,
591
+ num_inference_steps: Optional[int] = 50,
592
+ guidance_scale: Optional[float] = 5.0,
593
+ negative_prompt: Optional[Union[str, List[str]]] = None,
594
+ num_images_per_prompt: Optional[int] = 1,
595
+ eta: Optional[float] = 0.0,
596
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
597
+ latents: Optional[torch.Tensor] = None,
598
+ prompt_embeds: Optional[torch.Tensor] = None,
599
+ prompt_embeds_2: Optional[torch.Tensor] = None,
600
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
601
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
602
+ prompt_attention_mask: Optional[torch.Tensor] = None,
603
+ prompt_attention_mask_2: Optional[torch.Tensor] = None,
604
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
605
+ negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
606
+ output_type: Optional[str] = "latent",
607
+ return_dict: bool = True,
608
+ callback_on_step_end: Optional[
609
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
610
+ ] = None,
611
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
612
+ guidance_rescale: float = 0.0,
613
+ original_size: Optional[Tuple[int, int]] = (1024, 1024),
614
+ target_size: Optional[Tuple[int, int]] = None,
615
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
616
+ comfyui_progressbar: bool = False,
617
+ ):
618
+ r"""
619
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
620
+
621
+ Args:
622
+ prompt (`str` or `List[str]`, *optional*):
623
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
624
+ video_length (`int`, *optional*):
625
+ Length of the generated video (in frames).
626
+ height (`int`, *optional*):
627
+ Height of the generated image in pixels.
628
+ width (`int`, *optional*):
629
+ Width of the generated image in pixels.
630
+ num_inference_steps (`int`, *optional*, defaults to 50):
631
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
632
+ guidance_scale (`float`, *optional*, defaults to 5.0):
633
+ Encourages the model to align outputs with prompts. A higher value may decrease image quality.
634
+ negative_prompt (`str` or `List[str]`, *optional*):
635
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
636
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
637
+ Number of images to generate for each prompt.
638
+ eta (`float`, *optional*, defaults to 0.0):
639
+ Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
640
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
+ A generator to ensure reproducibility in image generation.
642
+ latents (`torch.Tensor`, *optional*):
643
+ Predefined latent tensors to condition generation.
644
+ prompt_embeds (`torch.Tensor`, *optional*):
645
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
646
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
647
+ Secondary text embeddings to supplement or replace the initial prompt embeddings.
648
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
649
+ Embeddings for negative prompts. Overrides string inputs if defined.
650
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
651
+ Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
652
+ prompt_attention_mask (`torch.Tensor`, *optional*):
653
+ Attention mask for the primary prompt embeddings.
654
+ prompt_attention_mask_2 (`torch.Tensor`, *optional*):
655
+ Attention mask for the secondary prompt embeddings.
656
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
657
+ Attention mask for negative prompt embeddings.
658
+ negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
659
+ Attention mask for secondary negative prompt embeddings.
660
+ output_type (`str`, *optional*, defaults to "latent"):
661
+ Format of the generated output, either as a PIL image or as a NumPy array.
662
+ return_dict (`bool`, *optional*, defaults to `True`):
663
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
664
+ callback_on_step_end (`Callable`, *optional*):
665
+ Functions called at the end of each denoising step.
666
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
667
+ Tensor names to be included in callback function calls.
668
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
669
+ Adjusts noise levels based on guidance scale.
670
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
671
+ Original dimensions of the output.
672
+ target_size (`Tuple[int, int]`, *optional*):
673
+ Desired output dimensions for calculations.
674
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
675
+ Coordinates for cropping.
676
+
677
+ Returns:
678
+ [`EasyAnimatePipelineOutput`] or `tuple`:
679
+ If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] whose `videos` field holds the
680
+ generated video is returned; otherwise, the generated video is returned directly as a tensor
681
+ or NumPy array, depending on `output_type`.
682
+
683
+ """
684
+
685
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
686
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
687
+
688
+ # 0. default height and width
689
+ height = int((height // 16) * 16)
690
+ width = int((width // 16) * 16)
691
+
692
+ # 1. Check inputs. Raise error if not correct
693
+ self.check_inputs(
694
+ prompt,
695
+ height,
696
+ width,
697
+ negative_prompt,
698
+ prompt_embeds,
699
+ negative_prompt_embeds,
700
+ prompt_attention_mask,
701
+ negative_prompt_attention_mask,
702
+ prompt_embeds_2,
703
+ negative_prompt_embeds_2,
704
+ prompt_attention_mask_2,
705
+ negative_prompt_attention_mask_2,
706
+ callback_on_step_end_tensor_inputs,
707
+ )
708
+ self._guidance_scale = guidance_scale
709
+ self._guidance_rescale = guidance_rescale
710
+ self._interrupt = False
711
+
712
+ # 2. Define call parameters
713
+ if prompt is not None and isinstance(prompt, str):
714
+ batch_size = 1
715
+ elif prompt is not None and isinstance(prompt, list):
716
+ batch_size = len(prompt)
717
+ else:
718
+ batch_size = prompt_embeds.shape[0]
719
+
720
+ device = self._execution_device
721
+
722
+ # 3. Encode input prompt
723
+ (
724
+ prompt_embeds,
725
+ negative_prompt_embeds,
726
+ prompt_attention_mask,
727
+ negative_prompt_attention_mask,
728
+ ) = self.encode_prompt(
729
+ prompt=prompt,
730
+ device=device,
731
+ dtype=self.transformer.dtype,
732
+ num_images_per_prompt=num_images_per_prompt,
733
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
734
+ negative_prompt=negative_prompt,
735
+ prompt_embeds=prompt_embeds,
736
+ negative_prompt_embeds=negative_prompt_embeds,
737
+ prompt_attention_mask=prompt_attention_mask,
738
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
739
+ text_encoder_index=0,
740
+ )
741
+ (
742
+ prompt_embeds_2,
743
+ negative_prompt_embeds_2,
744
+ prompt_attention_mask_2,
745
+ negative_prompt_attention_mask_2,
746
+ ) = self.encode_prompt(
747
+ prompt=prompt,
748
+ device=device,
749
+ dtype=self.transformer.dtype,
750
+ num_images_per_prompt=num_images_per_prompt,
751
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
752
+ negative_prompt=negative_prompt,
753
+ prompt_embeds=prompt_embeds_2,
754
+ negative_prompt_embeds=negative_prompt_embeds_2,
755
+ prompt_attention_mask=prompt_attention_mask_2,
756
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
757
+ text_encoder_index=1,
758
+ )
759
+ torch.cuda.empty_cache()
760
+
761
+ # 4. Prepare timesteps
762
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
763
+ timesteps = self.scheduler.timesteps
764
+ if comfyui_progressbar:
765
+ from comfy.utils import ProgressBar
766
+ pbar = ProgressBar(num_inference_steps + 1)
767
+
768
+ # 5. Prepare latent variables
769
+ num_channels_latents = self.transformer.config.in_channels
770
+ latents = self.prepare_latents(
771
+ batch_size * num_images_per_prompt,
772
+ num_channels_latents,
773
+ video_length,
774
+ height,
775
+ width,
776
+ prompt_embeds.dtype,
777
+ device,
778
+ generator,
779
+ latents,
780
+ )
781
+ if comfyui_progressbar:
782
+ pbar.update(1)
783
+
784
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
785
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
786
+
787
+ # 7 create image_rotary_emb, style embedding & time ids
788
+ grid_height = height // 8 // self.transformer.config.patch_size
789
+ grid_width = width // 8 // self.transformer.config.patch_size
790
+ if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
791
+ base_size_width = 720 // 8 // self.transformer.config.patch_size
792
+ base_size_height = 480 // 8 // self.transformer.config.patch_size
793
+
794
+ grid_crops_coords = get_resize_crop_region_for_grid(
795
+ (grid_height, grid_width), base_size_width, base_size_height
796
+ )
797
+ image_rotary_emb = get_3d_rotary_pos_embed(
798
+ self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
799
+ temporal_size=latents.size(2), use_real=True,
800
+ )
801
+ else:
802
+ base_size = 512 // 8 // self.transformer.config.patch_size
803
+ grid_crops_coords = get_resize_crop_region_for_grid(
804
+ (grid_height, grid_width), base_size, base_size
805
+ )
806
+ image_rotary_emb = get_2d_rotary_pos_embed(
807
+ self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
808
+ )
809
+
810
+ # Get other hunyuan params
811
+ style = torch.tensor([0], device=device)
812
+
813
+ target_size = target_size or (height, width)
814
+ add_time_ids = list(original_size + target_size + crops_coords_top_left)
815
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
816
+
817
+ if self.do_classifier_free_guidance:
818
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
819
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
820
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
821
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
822
+ add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
823
+ style = torch.cat([style] * 2, dim=0)
824
+
825
+ # To latents.device
826
+ prompt_embeds = prompt_embeds.to(device=device)
827
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
828
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
829
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
830
+ add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
831
+ batch_size * num_images_per_prompt, 1
832
+ )
833
+ style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
834
+
835
+ torch.cuda.empty_cache()
836
+ if self.enable_autocast_float8_transformer_flag:
837
+ origin_weight_dtype = self.transformer.dtype
838
+ self.transformer = self.transformer.to(torch.float8_e4m3fn)
839
+ # 8. Denoising loop
840
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
841
+ self._num_timesteps = len(timesteps)
842
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
843
+ for i, t in enumerate(timesteps):
844
+ if self.interrupt:
845
+ continue
846
+
847
+ # expand the latents if we are doing classifier free guidance
848
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
849
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
850
+
851
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
852
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
853
+ dtype=latent_model_input.dtype
854
+ )
855
+
856
+ # predict the noise residual
857
+ noise_pred = self.transformer(
858
+ latent_model_input,
859
+ t_expand,
860
+ encoder_hidden_states=prompt_embeds,
861
+ text_embedding_mask=prompt_attention_mask,
862
+ encoder_hidden_states_t5=prompt_embeds_2,
863
+ text_embedding_mask_t5=prompt_attention_mask_2,
864
+ image_meta_size=add_time_ids,
865
+ style=style,
866
+ image_rotary_emb=image_rotary_emb,
867
+ return_dict=False,
868
+ )[0]
869
+
870
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
871
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
872
+
873
+ # perform guidance
874
+ if self.do_classifier_free_guidance:
875
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
876
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
877
+
878
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
879
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
880
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
881
+
882
+ # compute the previous noisy sample x_t -> x_t-1
883
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
884
+
885
+ if callback_on_step_end is not None:
886
+ callback_kwargs = {}
887
+ for k in callback_on_step_end_tensor_inputs:
888
+ callback_kwargs[k] = locals()[k]
889
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
890
+
891
+ latents = callback_outputs.pop("latents", latents)
892
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
893
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
894
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
895
+ negative_prompt_embeds_2 = callback_outputs.pop(
896
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
897
+ )
898
+
899
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
900
+ progress_bar.update()
901
+
902
+ if XLA_AVAILABLE:
903
+ xm.mark_step()
904
+
905
+ if comfyui_progressbar:
906
+ pbar.update(1)
907
+
908
+ if self.enable_autocast_float8_transformer_flag:
909
+ self.transformer = self.transformer.to("cpu", origin_weight_dtype)
910
+
911
+ torch.cuda.empty_cache()
912
+ # Post-processing
913
+ video = self.decode_latents(latents)
914
+
915
+ # Convert to tensor
916
+ if output_type == "latent":
917
+ video = torch.from_numpy(video)
918
+
919
+ # Offload all models
920
+ self.maybe_free_model_hooks()
921
+
922
+ if not return_dict:
923
+ return video
924
+
925
+ return EasyAnimatePipelineOutput(videos=video)
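With the file complete, a minimal text-to-video call might look like the sketch below. It only uses arguments visible in the `__call__` signature above; the checkpoint path and generation settings are illustrative assumptions, and the pipeline is assumed to load from a diffusers-style checkpoint directory.

```py
# Minimal sketch, assuming a diffusers-style checkpoint directory for the new pipeline.
import torch
from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder import (
    EasyAnimatePipeline_Multi_Text_Encoder,
)

pipe = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained(
    "path/to/EasyAnimate-checkpoint", torch_dtype=torch.bfloat16
).to("cuda")

result = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    video_length=49,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=6.0,
)
video = result.videos  # EasyAnimatePipelineOutput.videos, produced by decode_latents()
```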
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py ADDED
@@ -0,0 +1,996 @@
1
+ # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import re
17
+ import urllib.parse as ul
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers import DiffusionPipeline, ImagePipelineOutput
25
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
28
+ from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
29
+ get_3d_rotary_pos_embed)
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
32
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
33
+ StableDiffusionSafetyChecker
34
+ from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
35
+ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
36
+ is_bs4_available, is_ftfy_available,
37
+ is_torch_xla_available, logging,
38
+ replace_example_docstring)
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from einops import rearrange
41
+ from PIL import Image
42
+ from tqdm import tqdm
43
+ from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
44
+ CLIPVisionModelWithProjection,
45
+ T5EncoderModel, T5Tokenizer)
46
+
47
+ from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
48
+ from .pipeline_easyanimate import EasyAnimatePipelineOutput
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ EXAMPLE_DOC_STRING = """
61
+ Examples:
62
+ ```py
63
+ >>> pass
64
+ ```
65
+ """
66
+
67
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
68
+ tw = tgt_width
69
+ th = tgt_height
70
+ h, w = src
71
+ r = h / w
72
+ if r > (th / tw):
73
+ resize_height = th
74
+ resize_width = int(round(th / h * w))
75
+ else:
76
+ resize_width = tw
77
+ resize_height = int(round(tw / w * h))
78
+
79
+ crop_top = int(round((th - resize_height) / 2.0))
80
+ crop_left = int(round((tw - resize_width) / 2.0))
81
+
82
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
83
+
84
+
85
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
86
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
87
+ """
88
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
89
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
90
+ """
91
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
92
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
93
+ # rescale the results from guidance (fixes overexposure)
94
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
95
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
96
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
97
+ return noise_cfg
98
+
99
+
100
+ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
101
+ r"""
102
+ Pipeline for text-to-video generation using EasyAnimate.
103
+
104
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
105
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
106
+
107
+ EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
108
+ HunyuanDiT team)
109
+
110
+ Args:
111
+ vae ([`AutoencoderKLMagvit`]):
112
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
113
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
114
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
115
+ EasyAnimate uses a fine-tuned [bilingual CLIP].
116
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
117
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
118
+ transformer ([`EasyAnimateTransformer3DModel`]):
119
+ The EasyAnimate model designed by Tencent Hunyuan.
120
+ text_encoder_2 (`T5EncoderModel`):
121
+ The mT5 embedder.
122
+ tokenizer_2 (`T5Tokenizer`):
123
+ The tokenizer for the mT5 embedder.
124
+ scheduler ([`DDIMScheduler`]):
125
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
126
+ """
127
+
128
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
129
+ _optional_components = [
130
+ "safety_checker",
131
+ "feature_extractor",
132
+ "text_encoder_2",
133
+ "tokenizer_2",
134
+ "text_encoder",
135
+ "tokenizer",
136
+ ]
137
+ _exclude_from_cpu_offload = ["safety_checker"]
138
+ _callback_tensor_inputs = [
139
+ "latents",
140
+ "prompt_embeds",
141
+ "negative_prompt_embeds",
142
+ "prompt_embeds_2",
143
+ "negative_prompt_embeds_2",
144
+ ]
145
+
146
+ def __init__(
147
+ self,
148
+ vae: AutoencoderKLMagvit,
149
+ text_encoder: BertModel,
150
+ tokenizer: BertTokenizer,
151
+ text_encoder_2: T5EncoderModel,
152
+ tokenizer_2: T5Tokenizer,
153
+ transformer: EasyAnimateTransformer3DModel,
154
+ scheduler: DDIMScheduler,
155
+ safety_checker: StableDiffusionSafetyChecker,
156
+ feature_extractor: CLIPImageProcessor,
157
+ requires_safety_checker: bool = True
158
+ ):
159
+ super().__init__()
160
+
161
+ self.register_modules(
162
+ vae=vae,
163
+ text_encoder=text_encoder,
164
+ tokenizer=tokenizer,
165
+ tokenizer_2=tokenizer_2,
166
+ transformer=transformer,
167
+ scheduler=scheduler,
168
+ safety_checker=safety_checker,
169
+ feature_extractor=feature_extractor,
170
+ text_encoder_2=text_encoder_2
171
+ )
172
+
173
+ if safety_checker is None and requires_safety_checker:
174
+ logger.warning(
175
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
176
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
177
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
178
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
179
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
180
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
181
+ )
182
+
183
+ if safety_checker is not None and feature_extractor is None:
184
+ raise ValueError(
185
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
186
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
187
+ )
188
+
189
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
190
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
191
+ self.mask_processor = VaeImageProcessor(
192
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
193
+ )
194
+ self.enable_autocast_float8_transformer_flag = False
195
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
196
+
197
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
198
+ super().enable_sequential_cpu_offload(*args, **kwargs)
199
+ if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
200
+ import accelerate
201
+ accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
202
+ self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
203
+
204
+ def encode_prompt(
205
+ self,
206
+ prompt: str,
207
+ device: torch.device,
208
+ dtype: torch.dtype,
209
+ num_images_per_prompt: int = 1,
210
+ do_classifier_free_guidance: bool = True,
211
+ negative_prompt: Optional[str] = None,
212
+ prompt_embeds: Optional[torch.Tensor] = None,
213
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
214
+ prompt_attention_mask: Optional[torch.Tensor] = None,
215
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
216
+ max_sequence_length: Optional[int] = None,
217
+ text_encoder_index: int = 0,
218
+ actual_max_sequence_length: int = 256
219
+ ):
220
+ r"""
221
+ Encodes the prompt into text encoder hidden states.
222
+
223
+ Args:
224
+ prompt (`str` or `List[str]`, *optional*):
225
+ prompt to be encoded
226
+ device: (`torch.device`):
227
+ torch device
228
+ dtype (`torch.dtype`):
229
+ torch dtype
230
+ num_images_per_prompt (`int`):
231
+ number of images that should be generated per prompt
232
+ do_classifier_free_guidance (`bool`):
233
+ whether to use classifier free guidance or not
234
+ negative_prompt (`str` or `List[str]`, *optional*):
235
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
236
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
237
+ less than `1`).
238
+ prompt_embeds (`torch.Tensor`, *optional*):
239
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
240
+ provided, text embeddings will be generated from `prompt` input argument.
241
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
242
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
243
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
244
+ argument.
245
+ prompt_attention_mask (`torch.Tensor`, *optional*):
246
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
247
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
248
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
249
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
250
+ text_encoder_index (`int`, *optional*):
251
+ Index of the text encoder to use. `0` for clip and `1` for T5.
252
+ """
253
+ tokenizers = [self.tokenizer, self.tokenizer_2]
254
+ text_encoders = [self.text_encoder, self.text_encoder_2]
255
+
256
+ tokenizer = tokenizers[text_encoder_index]
257
+ text_encoder = text_encoders[text_encoder_index]
258
+
259
+ if max_sequence_length is None:
260
+ if text_encoder_index == 0:
261
+ max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
262
+ if text_encoder_index == 1:
263
+ max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
264
+ else:
265
+ max_length = max_sequence_length
266
+
267
+ if prompt is not None and isinstance(prompt, str):
268
+ batch_size = 1
269
+ elif prompt is not None and isinstance(prompt, list):
270
+ batch_size = len(prompt)
271
+ else:
272
+ batch_size = prompt_embeds.shape[0]
273
+
274
+ if prompt_embeds is None:
275
+ text_inputs = tokenizer(
276
+ prompt,
277
+ padding="max_length",
278
+ max_length=max_length,
279
+ truncation=True,
280
+ return_attention_mask=True,
281
+ return_tensors="pt",
282
+ )
283
+ text_input_ids = text_inputs.input_ids
284
+ if text_input_ids.shape[-1] > actual_max_sequence_length:
285
+ reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
286
+ text_inputs = tokenizer(
287
+ reprompt,
288
+ padding="max_length",
289
+ max_length=max_length,
290
+ truncation=True,
291
+ return_attention_mask=True,
292
+ return_tensors="pt",
293
+ )
294
+ text_input_ids = text_inputs.input_ids
295
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
296
+
297
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
298
+ text_input_ids, untruncated_ids
299
+ ):
300
+ _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
301
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
302
+ logger.warning(
303
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
304
+ f" {_actual_max_sequence_length} tokens: {removed_text}"
305
+ )
306
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
307
+ if self.transformer.config.enable_text_attention_mask:
308
+ prompt_embeds = text_encoder(
309
+ text_input_ids.to(device),
310
+ attention_mask=prompt_attention_mask,
311
+ )
312
+ else:
313
+ prompt_embeds = text_encoder(
314
+ text_input_ids.to(device)
315
+ )
316
+ prompt_embeds = prompt_embeds[0]
317
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
318
+
319
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
320
+
321
+ bs_embed, seq_len, _ = prompt_embeds.shape
322
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
323
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
324
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
325
+
326
+ # get unconditional embeddings for classifier free guidance
327
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
328
+ uncond_tokens: List[str]
329
+ if negative_prompt is None:
330
+ uncond_tokens = [""] * batch_size
331
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
332
+ raise TypeError(
333
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
334
+ f" {type(prompt)}."
335
+ )
336
+ elif isinstance(negative_prompt, str):
337
+ uncond_tokens = [negative_prompt]
338
+ elif batch_size != len(negative_prompt):
339
+ raise ValueError(
340
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
341
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
342
+ " the batch size of `prompt`."
343
+ )
344
+ else:
345
+ uncond_tokens = negative_prompt
346
+
347
+ max_length = prompt_embeds.shape[1]
348
+ uncond_input = tokenizer(
349
+ uncond_tokens,
350
+ padding="max_length",
351
+ max_length=max_length,
352
+ truncation=True,
353
+ return_tensors="pt",
354
+ )
355
+ uncond_input_ids = uncond_input.input_ids
356
+ if uncond_input_ids.shape[-1] > actual_max_sequence_length:
357
+ reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
358
+ uncond_input = tokenizer(
359
+ reuncond_tokens,
360
+ padding="max_length",
361
+ max_length=max_length,
362
+ truncation=True,
363
+ return_attention_mask=True,
364
+ return_tensors="pt",
365
+ )
366
+ uncond_input_ids = uncond_input.input_ids
367
+
368
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
369
+ if self.transformer.config.enable_text_attention_mask:
370
+ negative_prompt_embeds = text_encoder(
371
+ uncond_input.input_ids.to(device),
372
+ attention_mask=negative_prompt_attention_mask,
373
+ )
374
+ else:
375
+ negative_prompt_embeds = text_encoder(
376
+ uncond_input.input_ids.to(device)
377
+ )
378
+ negative_prompt_embeds = negative_prompt_embeds[0]
379
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
380
+
381
+ if do_classifier_free_guidance:
382
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
383
+ seq_len = negative_prompt_embeds.shape[1]
384
+
385
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
386
+
387
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
388
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
389
+
390
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
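A hedged usage sketch of the two-pass prompt encoding above; `pipe` is assumed to be an already-instantiated pipeline, and the prompt is a placeholder:

```python
# Assumes `pipe` is an instantiated EasyAnimatePipeline_Multi_Text_Encoder_Control.
device = pipe._execution_device
dtype = pipe.transformer.dtype

# text_encoder_index=0 -> bilingual CLIP/BERT branch, text_encoder_index=1 -> mT5 branch,
# mirroring the two encode_prompt calls made later in __call__.
clip_embeds, clip_neg_embeds, clip_mask, clip_neg_mask = pipe.encode_prompt(
    prompt="a panda surfing a wave", device=device, dtype=dtype, text_encoder_index=0
)
t5_embeds, t5_neg_embeds, t5_mask, t5_neg_mask = pipe.encode_prompt(
    prompt="a panda surfing a wave", device=device, dtype=dtype, text_encoder_index=1
)
```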
391
+
392
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
393
+ def run_safety_checker(self, image, device, dtype):
394
+ if self.safety_checker is None:
395
+ has_nsfw_concept = None
396
+ else:
397
+ if torch.is_tensor(image):
398
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
399
+ else:
400
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
401
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
402
+ image, has_nsfw_concept = self.safety_checker(
403
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
404
+ )
405
+ return image, has_nsfw_concept
406
+
407
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
408
+ def prepare_extra_step_kwargs(self, generator, eta):
409
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
410
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
411
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
412
+ # and should be between [0, 1]
413
+
414
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
415
+ extra_step_kwargs = {}
416
+ if accepts_eta:
417
+ extra_step_kwargs["eta"] = eta
418
+
419
+ # check if the scheduler accepts generator
420
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
421
+ if accepts_generator:
422
+ extra_step_kwargs["generator"] = generator
423
+ return extra_step_kwargs
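The same signature-introspection pattern can be checked in isolation; this sketch assumes a plain `DDIMScheduler`, which accepts both `eta` and `generator`:

```python
import inspect

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
step_params = set(inspect.signature(scheduler.step).parameters.keys())
print("eta" in step_params, "generator" in step_params)  # True True
```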
424
+
425
+ def check_inputs(
426
+ self,
427
+ prompt,
428
+ height,
429
+ width,
430
+ negative_prompt=None,
431
+ prompt_embeds=None,
432
+ negative_prompt_embeds=None,
433
+ prompt_attention_mask=None,
434
+ negative_prompt_attention_mask=None,
435
+ prompt_embeds_2=None,
436
+ negative_prompt_embeds_2=None,
437
+ prompt_attention_mask_2=None,
438
+ negative_prompt_attention_mask_2=None,
439
+ callback_on_step_end_tensor_inputs=None,
440
+ ):
441
+ if height % 8 != 0 or width % 8 != 0:
442
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
443
+
444
+ if callback_on_step_end_tensor_inputs is not None and not all(
445
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
446
+ ):
447
+ raise ValueError(
448
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
449
+ )
450
+
451
+ if prompt is not None and prompt_embeds is not None:
452
+ raise ValueError(
453
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
454
+ " only forward one of the two."
455
+ )
456
+ elif prompt is None and prompt_embeds is None:
457
+ raise ValueError(
458
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
459
+ )
460
+ elif prompt is None and prompt_embeds_2 is None:
461
+ raise ValueError(
462
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
463
+ )
464
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
465
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
466
+
467
+ if prompt_embeds is not None and prompt_attention_mask is None:
468
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
469
+
470
+ if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
471
+ raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
472
+
473
+ if negative_prompt is not None and negative_prompt_embeds is not None:
474
+ raise ValueError(
475
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
476
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
477
+ )
478
+
479
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
480
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
481
+
482
+ if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
483
+ raise ValueError(
484
+ "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
485
+ )
486
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
487
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
488
+ raise ValueError(
489
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
490
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
491
+ f" {negative_prompt_embeds.shape}."
492
+ )
493
+ if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
494
+ if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
495
+ raise ValueError(
496
+ "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
497
+ f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
498
+ f" {negative_prompt_embeds_2.shape}."
499
+ )
500
+
501
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
502
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
503
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
504
+ if self.vae.cache_mag_vae:
505
+ mini_batch_encoder = self.vae.mini_batch_encoder
506
+ mini_batch_decoder = self.vae.mini_batch_decoder
507
+ shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
508
+ else:
509
+ mini_batch_encoder = self.vae.mini_batch_encoder
510
+ mini_batch_decoder = self.vae.mini_batch_decoder
511
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
512
+ else:
513
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
514
+
515
+ if isinstance(generator, list) and len(generator) != batch_size:
516
+ raise ValueError(
517
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
518
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
519
+ )
520
+
521
+ if latents is None:
522
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
523
+ else:
524
+ latents = latents.to(device)
525
+
526
+ # scale the initial noise by the standard deviation required by the scheduler
527
+ latents = latents * self.scheduler.init_noise_sigma
528
+ return latents
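The temporal size of the latent tensor follows the mini-batch arithmetic above; a small sketch of that calculation, with assumed compression values for illustration:

```python
# The mini_batch_encoder / mini_batch_decoder defaults below are assumptions for illustration only.
def latent_frames(video_length, mini_batch_encoder=4, mini_batch_decoder=1, cache_mag_vae=True):
    if video_length == 1:
        return 1
    if cache_mag_vae:
        return (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1
    return video_length // mini_batch_encoder * mini_batch_decoder

print(latent_frames(49))  # 13 latent frames for a 49-frame video
```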
529
+
530
+ def prepare_control_latents(
531
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
532
+ ):
533
+ # resize the mask to latents shape as we concatenate the mask to the latents
534
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
535
+ # and half precision
536
+
537
+ if mask is not None:
538
+ mask = mask.to(device=device, dtype=self.vae.dtype)
539
+ bs = 1
540
+ new_mask = []
541
+ for i in range(0, mask.shape[0], bs):
542
+ mask_bs = mask[i : i + bs]
543
+ mask_bs = self.vae.encode(mask_bs)[0]
544
+ mask_bs = mask_bs.mode()
545
+ new_mask.append(mask_bs)
546
+ mask = torch.cat(new_mask, dim = 0)
547
+ mask = mask * self.vae.config.scaling_factor
548
+
549
+ if masked_image is not None:
550
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
551
+ bs = 1
552
+ new_mask_pixel_values = []
553
+ for i in range(0, masked_image.shape[0], bs):
554
+ mask_pixel_values_bs = masked_image[i : i + bs]
555
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
556
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
557
+ new_mask_pixel_values.append(mask_pixel_values_bs)
558
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
559
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
560
+ else:
561
+ masked_image_latents = None
562
+
563
+ return mask, masked_image_latents
564
+
565
+ def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
566
+ if video.size()[2] <= mini_batch_encoder:
567
+ return video
568
+ prefix_index_before = mini_batch_encoder // 2
569
+ prefix_index_after = mini_batch_encoder - prefix_index_before
570
+ pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
571
+
572
+ # Encode middle videos
573
+ latents = self.vae.encode(pixel_values)[0]
574
+ latents = latents.mode()
575
+ # Decode middle videos
576
+ middle_video = self.vae.decode(latents)[0]
577
+
578
+ video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
579
+ return video
580
+
581
+ def decode_latents(self, latents):
582
+ video_length = latents.shape[2]
583
+ latents = 1 / self.vae.config.scaling_factor * latents
584
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
585
+ mini_batch_encoder = self.vae.mini_batch_encoder
586
+ mini_batch_decoder = self.vae.mini_batch_decoder
587
+ video = self.vae.decode(latents)[0]
588
+ video = video.clamp(-1, 1)
589
+ if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
590
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
591
+ else:
592
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
593
+ video = []
594
+ for frame_idx in tqdm(range(latents.shape[0])):
595
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
596
+ video = torch.cat(video)
597
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
598
+ video = (video / 2 + 0.5).clamp(0, 1)
599
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
600
+ video = video.cpu().float().numpy()
601
+ return video
602
+
603
+ @property
604
+ def guidance_scale(self):
605
+ return self._guidance_scale
606
+
607
+ @property
608
+ def guidance_rescale(self):
609
+ return self._guidance_rescale
610
+
611
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
612
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
613
+ # corresponds to doing no classifier free guidance.
614
+ @property
615
+ def do_classifier_free_guidance(self):
616
+ return self._guidance_scale > 1
617
+
618
+ @property
619
+ def num_timesteps(self):
620
+ return self._num_timesteps
621
+
622
+ @property
623
+ def interrupt(self):
624
+ return self._interrupt
625
+
626
+ def enable_autocast_float8_transformer(self):
627
+ self.enable_autocast_float8_transformer_flag = True
628
+
629
+ @torch.no_grad()
630
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
631
+ def __call__(
632
+ self,
633
+ prompt: Union[str, List[str]] = None,
634
+ video_length: Optional[int] = None,
635
+ height: Optional[int] = None,
636
+ width: Optional[int] = None,
637
+ control_video: Union[torch.FloatTensor] = None,
638
+ num_inference_steps: Optional[int] = 50,
639
+ guidance_scale: Optional[float] = 5.0,
640
+ negative_prompt: Optional[Union[str, List[str]]] = None,
641
+ num_images_per_prompt: Optional[int] = 1,
642
+ eta: Optional[float] = 0.0,
643
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
644
+ latents: Optional[torch.Tensor] = None,
645
+ prompt_embeds: Optional[torch.Tensor] = None,
646
+ prompt_embeds_2: Optional[torch.Tensor] = None,
647
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
648
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
649
+ prompt_attention_mask: Optional[torch.Tensor] = None,
650
+ prompt_attention_mask_2: Optional[torch.Tensor] = None,
651
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
652
+ negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
653
+ output_type: Optional[str] = "latent",
654
+ return_dict: bool = True,
655
+ callback_on_step_end: Optional[
656
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
657
+ ] = None,
658
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
659
+ guidance_rescale: float = 0.0,
660
+ original_size: Optional[Tuple[int, int]] = (1024, 1024),
661
+ target_size: Optional[Tuple[int, int]] = None,
662
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
663
+ comfyui_progressbar: bool = False,
664
+ ):
665
+ r"""
666
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
667
+
668
+ Args:
669
+ prompt (`str` or `List[str]`, *optional*):
670
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
671
+ video_length (`int`, *optional*):
672
+ Length of the generated video (in frames).
673
+ height (`int`, *optional*):
674
+ Height of the generated image in pixels.
675
+ width (`int`, *optional*):
676
+ Width of the generated image in pixels.
677
+ num_inference_steps (`int`, *optional*, defaults to 50):
678
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
679
+ guidance_scale (`float`, *optional*, defaults to 5.0):
680
+ Encourages the model to align outputs with prompts. A higher value may decrease image quality.
681
+ negative_prompt (`str` or `List[str]`, *optional*):
682
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
683
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
684
+ Number of images to generate for each prompt.
685
+ eta (`float`, *optional*, defaults to 0.0):
686
+ Corresponds to parameter eta (η) from the DDIM paper; only applies to the `DDIMScheduler`.
687
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
688
+ A generator to ensure reproducibility in image generation.
689
+ latents (`torch.Tensor`, *optional*):
690
+ Predefined latent tensors to condition generation.
691
+ prompt_embeds (`torch.Tensor`, *optional*):
692
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
693
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
694
+ Secondary text embeddings to supplement or replace the initial prompt embeddings.
695
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
696
+ Embeddings for negative prompts. Overrides string inputs if defined.
697
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
698
+ Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
699
+ prompt_attention_mask (`torch.Tensor`, *optional*):
700
+ Attention mask for the primary prompt embeddings.
701
+ prompt_attention_mask_2 (`torch.Tensor`, *optional*):
702
+ Attention mask for the secondary prompt embeddings.
703
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
704
+ Attention mask for negative prompt embeddings.
705
+ negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
706
+ Attention mask for secondary negative prompt embeddings.
707
+ output_type (`str`, *optional*, defaults to "latent"):
708
+ Format of the generated output, either as a PIL image or as a NumPy array.
709
+ return_dict (`bool`, *optional*, defaults to `True`):
710
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
711
+ callback_on_step_end (`Callable`, *optional*):
712
+ Functions called at the end of each denoising step.
713
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
714
+ Tensor names to be included in callback function calls.
715
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
716
+ Adjusts noise levels based on guidance scale.
717
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
718
+ Original dimensions of the output.
719
+ target_size (`Tuple[int, int]`, *optional*):
720
+ Desired output dimensions for calculations.
721
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
722
+ Coordinates for cropping.
723
+
724
+ Returns:
725
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
726
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
727
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
728
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
729
+ "not-safe-for-work" (nsfw) content.
730
+ """
731
+
732
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
733
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
734
+
735
+ # 0. default height and width
736
+ height = int((height // 16) * 16)
737
+ width = int((width // 16) * 16)
738
+
739
+ # 1. Check inputs. Raise error if not correct
740
+ self.check_inputs(
741
+ prompt,
742
+ height,
743
+ width,
744
+ negative_prompt,
745
+ prompt_embeds,
746
+ negative_prompt_embeds,
747
+ prompt_attention_mask,
748
+ negative_prompt_attention_mask,
749
+ prompt_embeds_2,
750
+ negative_prompt_embeds_2,
751
+ prompt_attention_mask_2,
752
+ negative_prompt_attention_mask_2,
753
+ callback_on_step_end_tensor_inputs,
754
+ )
755
+ self._guidance_scale = guidance_scale
756
+ self._guidance_rescale = guidance_rescale
757
+ self._interrupt = False
758
+
759
+ # 2. Define call parameters
760
+ if prompt is not None and isinstance(prompt, str):
761
+ batch_size = 1
762
+ elif prompt is not None and isinstance(prompt, list):
763
+ batch_size = len(prompt)
764
+ else:
765
+ batch_size = prompt_embeds.shape[0]
766
+
767
+ device = self._execution_device
768
+
769
+ # 3. Encode input prompt
770
+ (
771
+ prompt_embeds,
772
+ negative_prompt_embeds,
773
+ prompt_attention_mask,
774
+ negative_prompt_attention_mask,
775
+ ) = self.encode_prompt(
776
+ prompt=prompt,
777
+ device=device,
778
+ dtype=self.transformer.dtype,
779
+ num_images_per_prompt=num_images_per_prompt,
780
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
781
+ negative_prompt=negative_prompt,
782
+ prompt_embeds=prompt_embeds,
783
+ negative_prompt_embeds=negative_prompt_embeds,
784
+ prompt_attention_mask=prompt_attention_mask,
785
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
786
+ text_encoder_index=0,
787
+ )
788
+ (
789
+ prompt_embeds_2,
790
+ negative_prompt_embeds_2,
791
+ prompt_attention_mask_2,
792
+ negative_prompt_attention_mask_2,
793
+ ) = self.encode_prompt(
794
+ prompt=prompt,
795
+ device=device,
796
+ dtype=self.transformer.dtype,
797
+ num_images_per_prompt=num_images_per_prompt,
798
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
799
+ negative_prompt=negative_prompt,
800
+ prompt_embeds=prompt_embeds_2,
801
+ negative_prompt_embeds=negative_prompt_embeds_2,
802
+ prompt_attention_mask=prompt_attention_mask_2,
803
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
804
+ text_encoder_index=1,
805
+ )
806
+ torch.cuda.empty_cache()
807
+
808
+ # 4. Prepare timesteps
809
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
810
+ timesteps = self.scheduler.timesteps
811
+ if comfyui_progressbar:
812
+ from comfy.utils import ProgressBar
813
+ pbar = ProgressBar(num_inference_steps + 2)
814
+
815
+ # 5. Prepare latent variables
816
+ num_channels_latents = self.vae.config.latent_channels
817
+ latents = self.prepare_latents(
818
+ batch_size * num_images_per_prompt,
819
+ num_channels_latents,
820
+ video_length,
821
+ height,
822
+ width,
823
+ prompt_embeds.dtype,
824
+ device,
825
+ generator,
826
+ latents,
827
+ )
828
+ if comfyui_progressbar:
829
+ pbar.update(1)
830
+
831
+ if control_video is not None:
832
+ video_length = control_video.shape[2]
833
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
834
+ control_video = control_video.to(dtype=torch.float32)
835
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
836
+ else:
837
+ control_video = None
838
+ control_video_latents = self.prepare_control_latents(
839
+ None,
840
+ control_video,
841
+ batch_size,
842
+ height,
843
+ width,
844
+ prompt_embeds.dtype,
845
+ device,
846
+ generator,
847
+ self.do_classifier_free_guidance
848
+ )[1]
849
+ control_latents = (
850
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
851
+ )
852
+
853
+ if comfyui_progressbar:
854
+ pbar.update(1)
855
+
856
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
857
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
858
+
859
+ # 7 create image_rotary_emb, style embedding & time ids
860
+ grid_height = height // 8 // self.transformer.config.patch_size
861
+ grid_width = width // 8 // self.transformer.config.patch_size
862
+ if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
863
+ base_size_width = 720 // 8 // self.transformer.config.patch_size
864
+ base_size_height = 480 // 8 // self.transformer.config.patch_size
865
+
866
+ grid_crops_coords = get_resize_crop_region_for_grid(
867
+ (grid_height, grid_width), base_size_width, base_size_height
868
+ )
869
+ image_rotary_emb = get_3d_rotary_pos_embed(
870
+ self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
871
+ temporal_size=latents.size(2), use_real=True,
872
+ )
873
+ else:
874
+ base_size = 512 // 8 // self.transformer.config.patch_size
875
+ grid_crops_coords = get_resize_crop_region_for_grid(
876
+ (grid_height, grid_width), base_size, base_size
877
+ )
878
+ image_rotary_emb = get_2d_rotary_pos_embed(
879
+ self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
880
+ )
881
+
882
+ # Get other hunyuan params
883
+ style = torch.tensor([0], device=device)
884
+
885
+ target_size = target_size or (height, width)
886
+ add_time_ids = list(original_size + target_size + crops_coords_top_left)
887
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
888
+
889
+ if self.do_classifier_free_guidance:
890
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
891
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
892
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
893
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
894
+ add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
895
+ style = torch.cat([style] * 2, dim=0)
896
+
897
+ # To latents.device
898
+ prompt_embeds = prompt_embeds.to(device=device)
899
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
900
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
901
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
902
+ add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
903
+ batch_size * num_images_per_prompt, 1
904
+ )
905
+ style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
906
+
907
+ torch.cuda.empty_cache()
908
+ if self.enable_autocast_float8_transformer_flag:
909
+ origin_weight_dtype = self.transformer.dtype
910
+ self.transformer = self.transformer.to(torch.float8_e4m3fn)
911
+ # 8. Denoising loop
912
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
913
+ self._num_timesteps = len(timesteps)
914
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
915
+ for i, t in enumerate(timesteps):
916
+ if self.interrupt:
917
+ continue
918
+
919
+ # expand the latents if we are doing classifier free guidance
920
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
921
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
922
+
923
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
924
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
925
+ dtype=latent_model_input.dtype
926
+ )
927
+ # predict the noise residual
928
+ noise_pred = self.transformer(
929
+ latent_model_input,
930
+ t_expand,
931
+ encoder_hidden_states=prompt_embeds,
932
+ text_embedding_mask=prompt_attention_mask,
933
+ encoder_hidden_states_t5=prompt_embeds_2,
934
+ text_embedding_mask_t5=prompt_attention_mask_2,
935
+ image_meta_size=add_time_ids,
936
+ style=style,
937
+ image_rotary_emb=image_rotary_emb,
938
+ return_dict=False,
939
+ control_latents=control_latents,
940
+ )[0]
941
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
942
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
943
+
944
+ # perform guidance
945
+ if self.do_classifier_free_guidance:
946
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
947
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
948
+
949
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
950
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
951
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
952
+
953
+ # compute the previous noisy sample x_t -> x_t-1
954
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
955
+
956
+ if callback_on_step_end is not None:
957
+ callback_kwargs = {}
958
+ for k in callback_on_step_end_tensor_inputs:
959
+ callback_kwargs[k] = locals()[k]
960
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
961
+
962
+ latents = callback_outputs.pop("latents", latents)
963
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
964
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
965
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
966
+ negative_prompt_embeds_2 = callback_outputs.pop(
967
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
968
+ )
969
+
970
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
971
+ progress_bar.update()
972
+
973
+ if XLA_AVAILABLE:
974
+ xm.mark_step()
975
+
976
+ if comfyui_progressbar:
977
+ pbar.update(1)
978
+
979
+ if self.enable_autocast_float8_transformer_flag:
980
+ self.transformer = self.transformer.to("cpu", origin_weight_dtype)
981
+
982
+ torch.cuda.empty_cache()
983
+ # Post-processing
984
+ video = self.decode_latents(latents)
985
+
986
+ # Convert to tensor
987
+ if output_type == "latent":
988
+ video = torch.from_numpy(video)
989
+
990
+ # Offload all models
991
+ self.maybe_free_model_hooks()
992
+
993
+ if not return_dict:
994
+ return video
995
+
996
+ return EasyAnimatePipelineOutput(videos=video)
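A hypothetical end-to-end usage sketch of this control pipeline; the checkpoint path, resolution, and control tensor below are placeholders and not part of this commit:

```python
import torch

from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_control import (
    EasyAnimatePipeline_Multi_Text_Encoder_Control,
)

# Placeholder checkpoint path; an actual EasyAnimate checkpoint layout is assumed.
pipe = EasyAnimatePipeline_Multi_Text_Encoder_Control.from_pretrained(
    "path/to/EasyAnimate-checkpoint", torch_dtype=torch.bfloat16
).to("cuda")

# (batch, channels, frames, height, width), pixel values in [0, 1]
control_video = torch.zeros(1, 3, 49, 512, 512)

result = pipe(
    prompt="a panda surfing a wave",
    video_length=49,
    height=512,
    width=512,
    control_video=control_video,
    num_inference_steps=50,
    guidance_scale=6.0,
)
video = result.videos  # tensor in [0, 1], shape (batch, channels, frames, height, width)
```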
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py ADDED
@@ -0,0 +1,1334 @@
1
+ # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from diffusers import DiffusionPipeline
21
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
24
+ from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
25
+ get_3d_rotary_pos_embed)
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
28
+ StableDiffusionSafetyChecker
29
+ from diffusers.schedulers import DDIMScheduler
30
+ from diffusers.utils import (is_torch_xla_available, logging,
31
+ replace_example_docstring)
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from einops import rearrange
34
+ from PIL import Image
35
+ from tqdm import tqdm
36
+ from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
37
+ CLIPVisionModelWithProjection, T5Tokenizer,
38
+ T5EncoderModel)
39
+
40
+ from .pipeline_easyanimate import EasyAnimatePipelineOutput
41
+ from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
42
+
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ XLA_AVAILABLE = True
47
+ else:
48
+ XLA_AVAILABLE = False
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> pass
57
+ ```
58
+ """
59
+
60
+
61
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
62
+ tw = tgt_width
63
+ th = tgt_height
64
+ h, w = src
65
+ r = h / w
66
+ if r > (th / tw):
67
+ resize_height = th
68
+ resize_width = int(round(th / h * w))
69
+ else:
70
+ resize_width = tw
71
+ resize_height = int(round(tw / w * h))
72
+
73
+ crop_top = int(round((th - resize_height) / 2.0))
74
+ crop_left = int(round((tw - resize_width) / 2.0))
75
+
76
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
80
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
81
+ """
82
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
83
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
84
+ """
85
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
86
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
87
+ # rescale the results from guidance (fixes overexposure)
88
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
89
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
90
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
91
+ return noise_cfg
92
+
93
+
94
+ def resize_mask(mask, latent, process_first_frame_only=True):
95
+ latent_size = latent.size()
96
+
97
+ if process_first_frame_only:
98
+ target_size = list(latent_size[2:])
99
+ target_size[0] = 1
100
+ first_frame_resized = F.interpolate(
101
+ mask[:, :, 0:1, :, :],
102
+ size=target_size,
103
+ mode='trilinear',
104
+ align_corners=False
105
+ )
106
+
107
+ target_size = list(latent_size[2:])
108
+ target_size[0] = target_size[0] - 1
109
+ if target_size[0] != 0:
110
+ remaining_frames_resized = F.interpolate(
111
+ mask[:, :, 1:, :, :],
112
+ size=target_size,
113
+ mode='trilinear',
114
+ align_corners=False
115
+ )
116
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
117
+ else:
118
+ resized_mask = first_frame_resized
119
+ else:
120
+ target_size = list(latent_size[2:])
121
+ resized_mask = F.interpolate(
122
+ mask,
123
+ size=target_size,
124
+ mode='trilinear',
125
+ align_corners=False
126
+ )
127
+ return resized_mask
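A small sketch of `resize_mask` on dummy tensors (the latent channel count and frame counts are illustrative assumptions):

```python
import torch

mask = torch.ones(1, 1, 49, 512, 512)      # pixel-space mask: (batch, 1, frames, H, W)
latent = torch.zeros(1, 16, 13, 64, 64)    # latent video:     (batch, C, frames, H/8, W/8)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 13, 64, 64])
```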
128
+
129
+
130
+ def add_noise_to_reference_video(image, ratio=None):
131
+ if ratio is None:
132
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
133
+ sigma = torch.exp(sigma).to(image.dtype)
134
+ else:
135
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
136
+
137
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
138
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
139
+ image = image + image_noise
140
+ return image
141
+
142
+
143
+ class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline):
144
+ r"""
145
+ Pipeline for text-to-video generation using EasyAnimate.
146
+
147
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
148
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
149
+
150
+ EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
151
+ HunyuanDiT team)
152
+
153
+ Args:
154
+ vae ([`AutoencoderKLMagvit`]):
155
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
156
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
157
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
158
+ EasyAnimate uses a fine-tuned [bilingual CLIP].
159
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
160
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
161
+ transformer ([`EasyAnimateTransformer3DModel`]):
162
+ The EasyAnimate model designed by Tencent Hunyuan.
163
+ text_encoder_2 (`T5EncoderModel`):
164
+ The mT5 embedder.
165
+ tokenizer_2 (`T5Tokenizer`):
166
+ The tokenizer for the mT5 embedder.
167
+ scheduler ([`DDIMScheduler`]):
168
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
169
+ clip_image_processor (`CLIPImageProcessor`):
170
+ The CLIP image embedder.
171
+ clip_image_encoder (`CLIPVisionModelWithProjection`):
172
+ The image processor for the CLIP image embedder.
173
+ """
174
+
175
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
176
+ _optional_components = [
177
+ "safety_checker",
178
+ "feature_extractor",
179
+ "text_encoder_2",
180
+ "tokenizer_2",
181
+ "text_encoder",
182
+ "tokenizer",
183
+ "clip_image_encoder",
184
+ ]
185
+ _exclude_from_cpu_offload = ["safety_checker"]
186
+ _callback_tensor_inputs = [
187
+ "latents",
188
+ "prompt_embeds",
189
+ "negative_prompt_embeds",
190
+ "prompt_embeds_2",
191
+ "negative_prompt_embeds_2",
192
+ ]
193
+
194
+ def __init__(
195
+ self,
196
+ vae: AutoencoderKLMagvit,
197
+ text_encoder: BertModel,
198
+ tokenizer: BertTokenizer,
199
+ text_encoder_2: T5EncoderModel,
200
+ tokenizer_2: T5Tokenizer,
201
+ transformer: EasyAnimateTransformer3DModel,
202
+ scheduler: DDIMScheduler,
203
+ safety_checker: StableDiffusionSafetyChecker,
204
+ feature_extractor: CLIPImageProcessor,
205
+ requires_safety_checker: bool = True,
206
+ clip_image_processor: CLIPImageProcessor = None,
207
+ clip_image_encoder: CLIPVisionModelWithProjection = None,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ vae=vae,
213
+ text_encoder=text_encoder,
214
+ tokenizer=tokenizer,
215
+ tokenizer_2=tokenizer_2,
216
+ transformer=transformer,
217
+ scheduler=scheduler,
218
+ safety_checker=safety_checker,
219
+ feature_extractor=feature_extractor,
220
+ text_encoder_2=text_encoder_2,
221
+ clip_image_processor=clip_image_processor,
222
+ clip_image_encoder=clip_image_encoder,
223
+ )
224
+
225
+ if safety_checker is None and requires_safety_checker:
226
+ logger.warning(
227
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
228
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
229
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
230
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
231
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
232
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
233
+ )
234
+
235
+ if safety_checker is not None and feature_extractor is None:
236
+ raise ValueError(
237
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
238
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
239
+ )
240
+
241
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
242
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
243
+ self.mask_processor = VaeImageProcessor(
244
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
245
+ )
246
+ self.enable_autocast_float8_transformer_flag = False
247
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
248
+
249
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
250
+ super().enable_sequential_cpu_offload(*args, **kwargs)
251
+ if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
252
+ import accelerate
253
+ accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
254
+ self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
255
+
256
+ def encode_prompt(
257
+ self,
258
+ prompt: str,
259
+ device: torch.device,
260
+ dtype: torch.dtype,
261
+ num_images_per_prompt: int = 1,
262
+ do_classifier_free_guidance: bool = True,
263
+ negative_prompt: Optional[str] = None,
264
+ prompt_embeds: Optional[torch.Tensor] = None,
265
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
266
+ prompt_attention_mask: Optional[torch.Tensor] = None,
267
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
268
+ max_sequence_length: Optional[int] = None,
269
+ text_encoder_index: int = 0,
270
+ actual_max_sequence_length: int = 256
271
+ ):
272
+ r"""
273
+ Encodes the prompt into text encoder hidden states.
274
+
275
+ Args:
276
+ prompt (`str` or `List[str]`, *optional*):
277
+ prompt to be encoded
278
+ device: (`torch.device`):
279
+ torch device
280
+ dtype (`torch.dtype`):
281
+ torch dtype
282
+ num_images_per_prompt (`int`):
283
+ number of images that should be generated per prompt
284
+ do_classifier_free_guidance (`bool`):
285
+ whether to use classifier free guidance or not
286
+ negative_prompt (`str` or `List[str]`, *optional*):
287
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
288
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
289
+ less than `1`).
290
+ prompt_embeds (`torch.Tensor`, *optional*):
291
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
292
+ provided, text embeddings will be generated from `prompt` input argument.
293
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
294
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
295
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
296
+ argument.
297
+ prompt_attention_mask (`torch.Tensor`, *optional*):
298
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
299
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
300
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
301
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
302
+ text_encoder_index (`int`, *optional*):
303
+ Index of the text encoder to use. `0` for clip and `1` for T5.
304
+ """
305
+ tokenizers = [self.tokenizer, self.tokenizer_2]
306
+ text_encoders = [self.text_encoder, self.text_encoder_2]
307
+
308
+ tokenizer = tokenizers[text_encoder_index]
309
+ text_encoder = text_encoders[text_encoder_index]
310
+
311
+ if max_sequence_length is None:
312
+ if text_encoder_index == 0:
313
+ max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
314
+ if text_encoder_index == 1:
315
+ max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
316
+ else:
317
+ max_length = max_sequence_length
318
+
319
+ if prompt is not None and isinstance(prompt, str):
320
+ batch_size = 1
321
+ elif prompt is not None and isinstance(prompt, list):
322
+ batch_size = len(prompt)
323
+ else:
324
+ batch_size = prompt_embeds.shape[0]
325
+
326
+ if prompt_embeds is None:
327
+ text_inputs = tokenizer(
328
+ prompt,
329
+ padding="max_length",
330
+ max_length=max_length,
331
+ truncation=True,
332
+ return_attention_mask=True,
333
+ return_tensors="pt",
334
+ )
335
+ text_input_ids = text_inputs.input_ids
336
+ if text_input_ids.shape[-1] > actual_max_sequence_length:
337
+ reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
338
+ text_inputs = tokenizer(
339
+ reprompt,
340
+ padding="max_length",
341
+ max_length=max_length,
342
+ truncation=True,
343
+ return_attention_mask=True,
344
+ return_tensors="pt",
345
+ )
346
+ text_input_ids = text_inputs.input_ids
347
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
348
+
349
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
350
+ text_input_ids, untruncated_ids
351
+ ):
352
+ _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
353
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
354
+ logger.warning(
355
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
356
+ f" {_actual_max_sequence_length} tokens: {removed_text}"
357
+ )
358
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
359
+ if self.transformer.config.enable_text_attention_mask:
360
+ prompt_embeds = text_encoder(
361
+ text_input_ids.to(device),
362
+ attention_mask=prompt_attention_mask,
363
+ )
364
+ else:
365
+ prompt_embeds = text_encoder(
366
+ text_input_ids.to(device)
367
+ )
368
+ prompt_embeds = prompt_embeds[0]
369
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
370
+
371
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
372
+
373
+ bs_embed, seq_len, _ = prompt_embeds.shape
374
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
375
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
376
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
377
+
378
+ # get unconditional embeddings for classifier free guidance
379
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
380
+ uncond_tokens: List[str]
381
+ if negative_prompt is None:
382
+ uncond_tokens = [""] * batch_size
383
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
384
+ raise TypeError(
385
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
386
+ f" {type(prompt)}."
387
+ )
388
+ elif isinstance(negative_prompt, str):
389
+ uncond_tokens = [negative_prompt]
390
+ elif batch_size != len(negative_prompt):
391
+ raise ValueError(
392
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
393
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
394
+ " the batch size of `prompt`."
395
+ )
396
+ else:
397
+ uncond_tokens = negative_prompt
398
+
399
+ max_length = prompt_embeds.shape[1]
400
+ uncond_input = tokenizer(
401
+ uncond_tokens,
402
+ padding="max_length",
403
+ max_length=max_length,
404
+ truncation=True,
405
+ return_tensors="pt",
406
+ )
407
+ uncond_input_ids = uncond_input.input_ids
408
+ if uncond_input_ids.shape[-1] > actual_max_sequence_length:
409
+ reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
410
+ uncond_input = tokenizer(
411
+ reuncond_tokens,
412
+ padding="max_length",
413
+ max_length=max_length,
414
+ truncation=True,
415
+ return_attention_mask=True,
416
+ return_tensors="pt",
417
+ )
418
+ uncond_input_ids = uncond_input.input_ids
419
+
420
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
421
+ if self.transformer.config.enable_text_attention_mask:
422
+ negative_prompt_embeds = text_encoder(
423
+ uncond_input.input_ids.to(device),
424
+ attention_mask=negative_prompt_attention_mask,
425
+ )
426
+ else:
427
+ negative_prompt_embeds = text_encoder(
428
+ uncond_input.input_ids.to(device)
429
+ )
430
+ negative_prompt_embeds = negative_prompt_embeds[0]
431
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
432
+
433
+ if do_classifier_free_guidance:
434
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
435
+ seq_len = negative_prompt_embeds.shape[1]
436
+
437
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
438
+
439
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
440
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
441
+
442
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
443
+
444
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
445
+ def run_safety_checker(self, image, device, dtype):
446
+ if self.safety_checker is None:
447
+ has_nsfw_concept = None
448
+ else:
449
+ if torch.is_tensor(image):
450
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
451
+ else:
452
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
453
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
454
+ image, has_nsfw_concept = self.safety_checker(
455
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
456
+ )
457
+ return image, has_nsfw_concept
458
+
459
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
460
+ def prepare_extra_step_kwargs(self, generator, eta):
461
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
462
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
463
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
464
+ # and should be between [0, 1]
465
+
466
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
467
+ extra_step_kwargs = {}
468
+ if accepts_eta:
469
+ extra_step_kwargs["eta"] = eta
470
+
471
+ # check if the scheduler accepts generator
472
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
473
+ if accepts_generator:
474
+ extra_step_kwargs["generator"] = generator
475
+ return extra_step_kwargs
476
+
477
+ def check_inputs(
478
+ self,
479
+ prompt,
480
+ height,
481
+ width,
482
+ negative_prompt=None,
483
+ prompt_embeds=None,
484
+ negative_prompt_embeds=None,
485
+ prompt_attention_mask=None,
486
+ negative_prompt_attention_mask=None,
487
+ prompt_embeds_2=None,
488
+ negative_prompt_embeds_2=None,
489
+ prompt_attention_mask_2=None,
490
+ negative_prompt_attention_mask_2=None,
491
+ callback_on_step_end_tensor_inputs=None,
492
+ ):
493
+ if height % 8 != 0 or width % 8 != 0:
494
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
495
+
496
+ if callback_on_step_end_tensor_inputs is not None and not all(
497
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
498
+ ):
499
+ raise ValueError(
500
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
501
+ )
502
+
503
+ if prompt is not None and prompt_embeds is not None:
504
+ raise ValueError(
505
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
506
+ " only forward one of the two."
507
+ )
508
+ elif prompt is None and prompt_embeds is None:
509
+ raise ValueError(
510
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
511
+ )
512
+ elif prompt is None and prompt_embeds_2 is None:
513
+ raise ValueError(
514
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
515
+ )
516
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
517
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
518
+
519
+ if prompt_embeds is not None and prompt_attention_mask is None:
520
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
521
+
522
+ if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
523
+ raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
524
+
525
+ if negative_prompt is not None and negative_prompt_embeds is not None:
526
+ raise ValueError(
527
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
528
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
529
+ )
530
+
531
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
532
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
533
+
534
+ if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
535
+ raise ValueError(
536
+ "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
537
+ )
538
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
539
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
540
+ raise ValueError(
541
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
542
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
543
+ f" {negative_prompt_embeds.shape}."
544
+ )
545
+ if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
546
+ if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
547
+ raise ValueError(
548
+ "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
549
+ f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
550
+ f" {negative_prompt_embeds_2.shape}."
551
+ )
552
+
553
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
554
+ def get_timesteps(self, num_inference_steps, strength, device):
555
+ # get the original timestep using init_timestep
556
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
557
+
558
+ t_start = max(num_inference_steps - init_timestep, 0)
559
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
560
+
561
+ return timesteps, num_inference_steps - t_start
562
+
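# Worked example (illustrative): with num_inference_steps=50, strength=0.6 and a first-order
# scheduler, init_timestep = min(int(50 * 0.6), 50) = 30 and t_start = 50 - 30 = 20, so only the
# last 30 scheduler timesteps are kept and 30 denoising steps are returned.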
563
+ def prepare_mask_latents(
564
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
565
+ ):
566
+ # resize the mask to latents shape as we concatenate the mask to the latents
567
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
568
+ # and half precision
569
+ if mask is not None:
570
+ mask = mask.to(device=device, dtype=self.vae.dtype)
571
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
572
+ bs = 1
573
+ new_mask = []
574
+ for i in range(0, mask.shape[0], bs):
575
+ mask_bs = mask[i : i + bs]
576
+ mask_bs = self.vae.encode(mask_bs)[0]
577
+ mask_bs = mask_bs.mode()
578
+ new_mask.append(mask_bs)
579
+ mask = torch.cat(new_mask, dim = 0)
580
+ mask = mask * self.vae.config.scaling_factor
581
+
582
+ else:
583
+ if mask.shape[1] == 4:
584
+ mask = mask
585
+ else:
586
+ video_length = mask.shape[2]
587
+ mask = rearrange(mask, "b c f h w -> (b f) c h w")
588
+ mask = self._encode_vae_image(mask, generator=generator)
589
+ mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
590
+
591
+ if masked_image is not None:
592
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
593
+ if self.transformer.config.add_noise_in_inpaint_model:
594
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
595
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
596
+ bs = 1
597
+ new_mask_pixel_values = []
598
+ for i in range(0, masked_image.shape[0], bs):
599
+ mask_pixel_values_bs = masked_image[i : i + bs]
600
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
601
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
602
+ new_mask_pixel_values.append(mask_pixel_values_bs)
603
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
604
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
605
+
606
+ else:
607
+ if masked_image.shape[1] == 4:
608
+ masked_image_latents = masked_image
609
+ else:
610
+ video_length = masked_image.shape[2]
611
+ masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
612
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
613
+ masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
614
+
615
+ # aligning device to prevent device errors when concatenating it with the latent model input
616
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
617
+ else:
618
+ masked_image_latents = None
619
+
620
+ return mask, masked_image_latents
621
+
622
+ def prepare_latents(
623
+ self,
624
+ batch_size,
625
+ num_channels_latents,
626
+ height,
627
+ width,
628
+ video_length,
629
+ dtype,
630
+ device,
631
+ generator,
632
+ latents=None,
633
+ video=None,
634
+ timestep=None,
635
+ is_strength_max=True,
636
+ return_noise=False,
637
+ return_video_latents=False,
638
+ ):
639
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
640
+ if self.vae.cache_mag_vae:
641
+ mini_batch_encoder = self.vae.mini_batch_encoder
642
+ mini_batch_decoder = self.vae.mini_batch_decoder
643
+ shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
644
+ else:
645
+ mini_batch_encoder = self.vae.mini_batch_encoder
646
+ mini_batch_decoder = self.vae.mini_batch_decoder
647
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
648
+ else:
649
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
650
+
651
+ if isinstance(generator, list) and len(generator) != batch_size:
652
+ raise ValueError(
653
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
654
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
655
+ )
656
+
657
+ if return_video_latents or (latents is None and not is_strength_max):
658
+ video = video.to(device=device, dtype=self.vae.dtype)
659
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
660
+ bs = 1
661
+ new_video = []
662
+ for i in range(0, video.shape[0], bs):
663
+ video_bs = video[i : i + bs]
664
+ video_bs = self.vae.encode(video_bs)[0]
665
+ video_bs = video_bs.sample()
666
+ new_video.append(video_bs)
667
+ video = torch.cat(new_video, dim = 0)
668
+ video = video * self.vae.config.scaling_factor
669
+
670
+ else:
671
+ if video.shape[1] == 4:
672
+ video = video
673
+ else:
674
+ video_length = video.shape[2]
675
+ video = rearrange(video, "b c f h w -> (b f) c h w")
676
+ video = self._encode_vae_image(video, generator=generator)
677
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
678
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
679
+ video_latents = video_latents.to(device=device, dtype=dtype)
680
+
681
+ if latents is None:
682
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
683
+ # if strength is 1. then initialise the latents to noise, else initialise to image + noise
684
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
685
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
686
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
687
+ else:
688
+ noise = latents.to(device)
689
+ latents = noise * self.scheduler.init_noise_sigma
690
+
691
+ # scale the initial noise by the standard deviation required by the scheduler
692
+ outputs = (latents,)
693
+
694
+ if return_noise:
695
+ outputs += (noise,)
696
+
697
+ if return_video_latents:
698
+ outputs += (video_latents,)
699
+
700
+ return outputs
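# Shape check (illustrative VAE settings, not necessarily this repo's defaults): with cache_mag_vae
# enabled, mini_batch_encoder=4 and mini_batch_decoder=1, a request for video_length=49 frames
# allocates (49 - 1) // 4 * 1 + 1 = 13 latent frames, while video_length=1 stays at a single frame.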
701
+
702
+ def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
703
+ if video.size()[2] <= mini_batch_encoder:
704
+ return video
705
+ prefix_index_before = mini_batch_encoder // 2
706
+ prefix_index_after = mini_batch_encoder - prefix_index_before
707
+ pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
708
+
709
+ # Encode middle videos
710
+ latents = self.vae.encode(pixel_values)[0]
711
+ latents = latents.mode()
712
+ # Decode middle videos
713
+ middle_video = self.vae.decode(latents)[0]
714
+
715
+ video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
716
+ return video
717
+
718
+ def decode_latents(self, latents):
719
+ video_length = latents.shape[2]
720
+ latents = 1 / self.vae.config.scaling_factor * latents
721
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
722
+ mini_batch_encoder = self.vae.mini_batch_encoder
723
+ mini_batch_decoder = self.vae.mini_batch_decoder
724
+ video = self.vae.decode(latents)[0]
725
+ video = video.clamp(-1, 1)
726
+ if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
727
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
728
+ else:
729
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
730
+ video = []
731
+ for frame_idx in tqdm(range(latents.shape[0])):
732
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
733
+ video = torch.cat(video)
734
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
735
+ video = (video / 2 + 0.5).clamp(0, 1)
736
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
737
+ video = video.cpu().float().numpy()
738
+ return video
739
+
740
+ @property
741
+ def guidance_scale(self):
742
+ return self._guidance_scale
743
+
744
+ @property
745
+ def guidance_rescale(self):
746
+ return self._guidance_rescale
747
+
748
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
749
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
750
+ # corresponds to doing no classifier free guidance.
751
+ @property
752
+ def do_classifier_free_guidance(self):
753
+ return self._guidance_scale > 1
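# Example: guidance_scale=5.0 enables classifier-free guidance, so the latent batch fed to the
# transformer below is doubled (unconditional + conditional); guidance_scale<=1.0 disables it.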
754
+
755
+ @property
756
+ def num_timesteps(self):
757
+ return self._num_timesteps
758
+
759
+ @property
760
+ def interrupt(self):
761
+ return self._interrupt
762
+
763
+ def enable_autocast_float8_transformer(self):
764
+ self.enable_autocast_float8_transformer_flag = True
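# Usage note: calling pipeline.enable_autocast_float8_transformer() before __call__ makes the
# denoising loop below store the transformer weights in torch.float8_e4m3fn during sampling and
# restore the original dtype afterwards.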
765
+
766
+ @torch.no_grad()
767
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
768
+ def __call__(
769
+ self,
770
+ prompt: Union[str, List[str]] = None,
771
+ video_length: Optional[int] = None,
772
+ video: Union[torch.FloatTensor] = None,
773
+ mask_video: Union[torch.FloatTensor] = None,
774
+ masked_video_latents: Union[torch.FloatTensor] = None,
775
+ height: Optional[int] = None,
776
+ width: Optional[int] = None,
777
+ num_inference_steps: Optional[int] = 50,
778
+ guidance_scale: Optional[float] = 5.0,
779
+ negative_prompt: Optional[Union[str, List[str]]] = None,
780
+ num_images_per_prompt: Optional[int] = 1,
781
+ eta: Optional[float] = 0.0,
782
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
783
+ latents: Optional[torch.Tensor] = None,
784
+ prompt_embeds: Optional[torch.Tensor] = None,
785
+ prompt_embeds_2: Optional[torch.Tensor] = None,
786
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
787
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
788
+ prompt_attention_mask: Optional[torch.Tensor] = None,
789
+ prompt_attention_mask_2: Optional[torch.Tensor] = None,
790
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
791
+ negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
792
+ output_type: Optional[str] = "latent",
793
+ return_dict: bool = True,
794
+ callback_on_step_end: Optional[
795
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
796
+ ] = None,
797
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
798
+ guidance_rescale: float = 0.0,
799
+ original_size: Optional[Tuple[int, int]] = (1024, 1024),
800
+ target_size: Optional[Tuple[int, int]] = None,
801
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
802
+ clip_image: Image = None,
803
+ clip_apply_ratio: float = 0.40,
804
+ strength: float = 1.0,
805
+ noise_aug_strength: float = 0.0563,
806
+ comfyui_progressbar: bool = False,
807
+ ):
808
+ r"""
809
+ The call function to the pipeline for video generation with EasyAnimate.
810
+
811
+ Examples:
812
+ prompt (`str` or `List[str]`, *optional*):
813
+ The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.
814
+ video_length (`int`, *optional*):
815
+ The number of frames of the video to be generated. This parameter determines the temporal length and
816
+ continuity of generated content.
817
+ video (`torch.FloatTensor`, *optional*):
818
+ A tensor representing an input video, which can be modified depending on the prompts provided.
819
+ mask_video (`torch.FloatTensor`, *optional*):
820
+ A tensor to specify areas of the video to be masked (omitted from generation).
821
+ masked_video_latents (`torch.FloatTensor`, *optional*):
822
+ Latents from masked portions of the video, utilized during image generation.
823
+ height (`int`, *optional*):
824
+ The height in pixels of the generated image or video frames.
825
+ width (`int`, *optional*):
826
+ The width in pixels of the generated image or video frames.
827
+ num_inference_steps (`int`, *optional*, defaults to 50):
828
+ The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
829
+ inference time. This parameter is modulated by `strength`.
830
+ guidance_scale (`float`, *optional*, defaults to 5.0):
831
+ A higher guidance scale value encourages the model to generate images closely linked to the text
832
+ `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
833
+ negative_prompt (`str` or `List[str]`, *optional*):
834
+ The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
835
+ provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale <= 1`).
836
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
837
+ The number of images to generate per prompt.
838
+ eta (`float`, *optional*, defaults to 0.0):
839
+ A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
840
+ [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
841
+ inference process.
842
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
843
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
844
+ random seeds which helps in making generation deterministic.
845
+ latents (`torch.Tensor`, *optional*):
846
+ A pre-computed latent representation which can be used to guide the generation process.
847
+ prompt_embeds (`torch.Tensor`, *optional*):
848
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
849
+ provided, embeddings are generated from the `prompt` input argument.
850
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
851
+ Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
852
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
853
+ Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
854
+ If not provided, embeddings are generated from the `negative_prompt` argument.
855
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
856
+ Secondary set of pre-generated negative text embeddings for further control.
857
+ prompt_attention_mask (`torch.Tensor`, *optional*):
858
+ Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
859
+ `prompt_embeds`.
860
+ prompt_attention_mask_2 (`torch.Tensor`, *optional*):
861
+ Attention mask for the secondary prompt embedding.
862
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
863
+ Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
864
+ negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
865
+ Attention mask for the secondary negative prompt embedding.
866
+ output_type (`str`, *optional*, defaults to `"latent"`):
867
+ The output format of the generated video. `"latent"` returns the decoded frames as a `torch.Tensor`;
868
+ any other value returns them as a `np.ndarray`.
869
+ return_dict (`bool`, *optional*, defaults to `True`):
870
+ If set to `True`, an [`EasyAnimatePipelineOutput`] will be returned; otherwise, the generated
871
+ video is returned directly.
872
+ callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
873
+ A callback function (or a list of them) that will be executed at the end of each denoising step,
874
+ allowing for custom processing during generation.
875
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
876
+ Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
877
+ inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
878
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
879
+ Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
880
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
881
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
882
+ The original dimensions of the image. Used to compute time ids during the generation process.
883
+ target_size (`Tuple[int, int]`, *optional*):
884
+ The targeted dimensions of the generated image, also utilized in the time id calculations.
885
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
886
+ Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
887
+ clip_image (`Image`, *optional*):
888
+ An optional image to assist in the generation process. It may be used as an additional visual cue.
889
+ clip_apply_ratio (`float`, *optional*, defaults to 0.40):
890
+ Fraction of the denoising steps, counted from the end of sampling, during which the CLIP image features are applied; earlier steps use zeroed CLIP features.
891
+ strength (`float`, *optional*, defaults to 1.0):
892
+ Indicates how much to transform the reference `video`. Must be between 0 and 1. A value of 1
893
+ starts from pure noise and effectively ignores the input video; lower values preserve more of it.
894
+ comfyui_progressbar (`bool`, *optional*, defaults to `False`):
895
+ Enables a progress bar in ComfyUI, providing visual feedback during the generation process.
896
+
897
+ Examples:
898
+ # A usage sketch for generating a video from a prompt is provided after this file's diff.
899
+
900
+ Returns:
901
+ [`EasyAnimatePipelineOutput`] or `torch.Tensor`/`np.ndarray`:
902
+ Returns an [`EasyAnimatePipelineOutput`] containing the generated video when `return_dict` is
903
+ `True`; otherwise the generated video is returned directly.
905
+ """
906
+
907
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
908
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
909
+
910
+ # 0. default height and width
911
+ height = int(height // 16 * 16)
912
+ width = int(width // 16 * 16)
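# Both dimensions are snapped down to multiples of 16, e.g. height=730 becomes 720 while 512 and
# 1280 are left unchanged.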
913
+
914
+ # 1. Check inputs. Raise error if not correct
915
+ self.check_inputs(
916
+ prompt,
917
+ height,
918
+ width,
919
+ negative_prompt,
920
+ prompt_embeds,
921
+ negative_prompt_embeds,
922
+ prompt_attention_mask,
923
+ negative_prompt_attention_mask,
924
+ prompt_embeds_2,
925
+ negative_prompt_embeds_2,
926
+ prompt_attention_mask_2,
927
+ negative_prompt_attention_mask_2,
928
+ callback_on_step_end_tensor_inputs,
929
+ )
930
+ self._guidance_scale = guidance_scale
931
+ self._guidance_rescale = guidance_rescale
932
+ self._interrupt = False
933
+
934
+ # 2. Define call parameters
935
+ if prompt is not None and isinstance(prompt, str):
936
+ batch_size = 1
937
+ elif prompt is not None and isinstance(prompt, list):
938
+ batch_size = len(prompt)
939
+ else:
940
+ batch_size = prompt_embeds.shape[0]
941
+
942
+ device = self._execution_device
943
+
944
+ # 3. Encode input prompt
945
+ (
946
+ prompt_embeds,
947
+ negative_prompt_embeds,
948
+ prompt_attention_mask,
949
+ negative_prompt_attention_mask,
950
+ ) = self.encode_prompt(
951
+ prompt=prompt,
952
+ device=device,
953
+ dtype=self.transformer.dtype,
954
+ num_images_per_prompt=num_images_per_prompt,
955
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
956
+ negative_prompt=negative_prompt,
957
+ prompt_embeds=prompt_embeds,
958
+ negative_prompt_embeds=negative_prompt_embeds,
959
+ prompt_attention_mask=prompt_attention_mask,
960
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
961
+ text_encoder_index=0,
962
+ )
963
+ (
964
+ prompt_embeds_2,
965
+ negative_prompt_embeds_2,
966
+ prompt_attention_mask_2,
967
+ negative_prompt_attention_mask_2,
968
+ ) = self.encode_prompt(
969
+ prompt=prompt,
970
+ device=device,
971
+ dtype=self.transformer.dtype,
972
+ num_images_per_prompt=num_images_per_prompt,
973
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
974
+ negative_prompt=negative_prompt,
975
+ prompt_embeds=prompt_embeds_2,
976
+ negative_prompt_embeds=negative_prompt_embeds_2,
977
+ prompt_attention_mask=prompt_attention_mask_2,
978
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
979
+ text_encoder_index=1,
980
+ )
981
+ torch.cuda.empty_cache()
982
+
983
+ # 4. set timesteps
984
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
985
+ timesteps, num_inference_steps = self.get_timesteps(
986
+ num_inference_steps=num_inference_steps, strength=strength, device=device
987
+ )
988
+ if comfyui_progressbar:
989
+ from comfy.utils import ProgressBar
990
+ pbar = ProgressBar(num_inference_steps + 3)
991
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
992
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
993
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
994
+ is_strength_max = strength == 1.0
995
+
996
+ if video is not None:
997
+ video_length = video.shape[2]
998
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
999
+ init_video = init_video.to(dtype=torch.float32)
1000
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
1001
+ else:
1002
+ init_video = None
1003
+
1004
+ # Prepare latent variables
1005
+ num_channels_latents = self.vae.config.latent_channels
1006
+ num_channels_transformer = self.transformer.config.in_channels
1007
+ return_image_latents = num_channels_transformer == num_channels_latents
1008
+
1009
+ # 5. Prepare latents.
1010
+ latents_outputs = self.prepare_latents(
1011
+ batch_size * num_images_per_prompt,
1012
+ num_channels_latents,
1013
+ height,
1014
+ width,
1015
+ video_length,
1016
+ prompt_embeds.dtype,
1017
+ device,
1018
+ generator,
1019
+ latents,
1020
+ video=init_video,
1021
+ timestep=latent_timestep,
1022
+ is_strength_max=is_strength_max,
1023
+ return_noise=True,
1024
+ return_video_latents=return_image_latents,
1025
+ )
1026
+ if return_image_latents:
1027
+ latents, noise, image_latents = latents_outputs
1028
+ else:
1029
+ latents, noise = latents_outputs
1030
+
1031
+ if comfyui_progressbar:
1032
+ pbar.update(1)
1033
+
1034
+ # 6. Prepare clip latents if needed.
1035
+ if clip_image is not None and self.transformer.enable_clip_in_inpaint:
1036
+ inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
1037
+ inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
1038
+ clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
1039
+ clip_encoder_hidden_states_neg = torch.zeros(
1040
+ [
1041
+ batch_size,
1042
+ int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
1043
+ int(self.clip_image_encoder.config.hidden_size)
1044
+ ]
1045
+ ).to(latents.device, dtype=latents.dtype)
1046
+
1047
+ clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
1048
+ clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
1049
+
1050
+ clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
1051
+ clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
1052
+
1053
+ elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
1054
+ clip_encoder_hidden_states = torch.zeros(
1055
+ [
1056
+ batch_size,
1057
+ int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
1058
+ int(self.clip_image_encoder.config.hidden_size)
1059
+ ]
1060
+ ).to(latents.device, dtype=latents.dtype)
1061
+
1062
+ clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
1063
+ clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
1064
+
1065
+ clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
1066
+ clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
1067
+
1068
+ else:
1069
+ clip_encoder_hidden_states_input = None
1070
+ clip_attention_mask_input = None
1071
+ if comfyui_progressbar:
1072
+ pbar.update(1)
1073
+
1074
+ # 7. Prepare inpaint latents if needed.
1075
+ if mask_video is not None:
1076
+ if (mask_video == 255).all():
1077
+ # Use zero latents if we want to t2v.
1078
+ if self.transformer.resize_inpaint_mask_directly:
1079
+ mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype)
1080
+ else:
1081
+ mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1082
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1083
+
1084
+ mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
1085
+ masked_video_latents_input = (
1086
+ torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1087
+ )
1088
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1089
+ else:
1090
+ # Prepare mask latent variables
1091
+ video_length = video.shape[2]
1092
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
1093
+ mask_condition = mask_condition.to(dtype=torch.float32)
1094
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
1095
+
1096
+ if num_channels_transformer != num_channels_latents:
1097
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
1098
+ if masked_video_latents is None:
1099
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
1100
+ else:
1101
+ masked_video = masked_video_latents
1102
+
1103
+ if self.transformer.resize_inpaint_mask_directly:
1104
+ _, masked_video_latents = self.prepare_mask_latents(
1105
+ None,
1106
+ masked_video,
1107
+ batch_size,
1108
+ height,
1109
+ width,
1110
+ prompt_embeds.dtype,
1111
+ device,
1112
+ generator,
1113
+ self.do_classifier_free_guidance,
1114
+ noise_aug_strength=noise_aug_strength,
1115
+ )
1116
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
1117
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
1118
+ else:
1119
+ mask_latents, masked_video_latents = self.prepare_mask_latents(
1120
+ mask_condition_tile,
1121
+ masked_video,
1122
+ batch_size,
1123
+ height,
1124
+ width,
1125
+ prompt_embeds.dtype,
1126
+ device,
1127
+ generator,
1128
+ self.do_classifier_free_guidance,
1129
+ noise_aug_strength=noise_aug_strength,
1130
+ )
1131
+
1132
+ mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
1133
+ masked_video_latents_input = (
1134
+ torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1135
+ )
1136
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1137
+ else:
1138
+ inpaint_latents = None
1139
+
1140
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1141
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1142
+ else:
1143
+ if num_channels_transformer != num_channels_latents:
1144
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
1145
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1146
+
1147
+ mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
1148
+ masked_video_latents_input = (
1149
+ torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1150
+ )
1151
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1152
+ else:
1153
+ mask = torch.zeros_like(init_video[:, :1])
1154
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
1155
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1156
+
1157
+ inpaint_latents = None
1158
+ if comfyui_progressbar:
1159
+ pbar.update(1)
1160
+
1161
+ # Check that sizes of mask, masked image and latents match
1162
+ if num_channels_transformer != num_channels_latents:
1163
+ num_channels_mask = mask_latents.shape[1]
1164
+ num_channels_masked_image = masked_video_latents.shape[1]
1165
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
1166
+ raise ValueError(
1167
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
1168
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1169
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1170
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1171
+ " `pipeline.transformer` or your `mask_image` or `image` input."
1172
+ )
1173
+
1174
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1175
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1176
+
1177
+ # 9. Create image_rotary_emb, style embedding & time ids
1178
+ grid_height = height // 8 // self.transformer.config.patch_size
1179
+ grid_width = width // 8 // self.transformer.config.patch_size
1180
+ if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
1181
+ base_size_width = 720 // 8 // self.transformer.config.patch_size
1182
+ base_size_height = 480 // 8 // self.transformer.config.patch_size
1183
+
1184
+ grid_crops_coords = get_resize_crop_region_for_grid(
1185
+ (grid_height, grid_width), base_size_width, base_size_height
1186
+ )
1187
+ image_rotary_emb = get_3d_rotary_pos_embed(
1188
+ self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
1189
+ temporal_size=latents.size(2), use_real=True,
1190
+ )
1191
+ else:
1192
+ base_size = 512 // 8 // self.transformer.config.patch_size
1193
+ grid_crops_coords = get_resize_crop_region_for_grid(
1194
+ (grid_height, grid_width), base_size, base_size
1195
+ )
1196
+ image_rotary_emb = get_2d_rotary_pos_embed(
1197
+ self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
1198
+ )
1199
+
1200
+ # Get other hunyuan params
1201
+ style = torch.tensor([0], device=device)
1202
+
1203
+ target_size = target_size or (height, width)
1204
+ add_time_ids = list(original_size + target_size + crops_coords_top_left)
1205
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
1206
+
1207
+ if self.do_classifier_free_guidance:
1208
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1209
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
1210
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1211
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
1212
+ add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
1213
+ style = torch.cat([style] * 2, dim=0)
1214
+
1215
+ prompt_embeds = prompt_embeds.to(device=device)
1216
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
1217
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
1218
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
1219
+ add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
1220
+ batch_size * num_images_per_prompt, 1
1221
+ )
1222
+ style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
1223
+
1224
+ torch.cuda.empty_cache()
1225
+ if self.enable_autocast_float8_transformer_flag:
1226
+ origin_weight_dtype = self.transformer.dtype
1227
+ self.transformer = self.transformer.to(torch.float8_e4m3fn)
1228
+ # 10. Denoising loop
1229
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1230
+ self._num_timesteps = len(timesteps)
1231
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1232
+ for i, t in enumerate(timesteps):
1233
+ if self.interrupt:
1234
+ continue
1235
+
1236
+ # expand the latents if we are doing classifier free guidance
1237
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1238
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1239
+
1240
+ if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
1241
+ clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
1242
+ clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
1243
+ else:
1244
+ clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
1245
+ clip_attention_mask_actual_input = clip_attention_mask_input
1246
+
1247
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
1248
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
1249
+ dtype=latent_model_input.dtype
1250
+ )
1251
+
1252
+ # predict the noise residual
1253
+ noise_pred = self.transformer(
1254
+ latent_model_input,
1255
+ t_expand,
1256
+ encoder_hidden_states=prompt_embeds,
1257
+ text_embedding_mask=prompt_attention_mask,
1258
+ encoder_hidden_states_t5=prompt_embeds_2,
1259
+ text_embedding_mask_t5=prompt_attention_mask_2,
1260
+ image_meta_size=add_time_ids,
1261
+ style=style,
1262
+ image_rotary_emb=image_rotary_emb,
1263
+ inpaint_latents=inpaint_latents,
1264
+ clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
1265
+ clip_attention_mask=clip_attention_mask_actual_input,
1266
+ return_dict=False,
1267
+ )[0]
1268
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
1269
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
1270
+
1271
+ # perform guidance
1272
+ if self.do_classifier_free_guidance:
1273
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1274
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1275
+
1276
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1277
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1278
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1279
+
1280
+ # compute the previous noisy sample x_t -> x_t-1
1281
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1282
+
1283
+ if num_channels_transformer == 4:
1284
+ init_latents_proper = image_latents
1285
+ init_mask = mask
1286
+ if i < len(timesteps) - 1:
1287
+ noise_timestep = timesteps[i + 1]
1288
+ init_latents_proper = self.scheduler.add_noise(
1289
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1290
+ )
1291
+
1292
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1293
+
1294
+ if callback_on_step_end is not None:
1295
+ callback_kwargs = {}
1296
+ for k in callback_on_step_end_tensor_inputs:
1297
+ callback_kwargs[k] = locals()[k]
1298
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1299
+
1300
+ latents = callback_outputs.pop("latents", latents)
1301
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1302
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1303
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
1304
+ negative_prompt_embeds_2 = callback_outputs.pop(
1305
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
1306
+ )
1307
+
1308
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1309
+ progress_bar.update()
1310
+
1311
+ if XLA_AVAILABLE:
1312
+ xm.mark_step()
1313
+
1314
+ if comfyui_progressbar:
1315
+ pbar.update(1)
1316
+
1317
+ if self.enable_autocast_float8_transformer_flag:
1318
+ self.transformer = self.transformer.to("cpu", origin_weight_dtype)
1319
+
1320
+ torch.cuda.empty_cache()
1321
+ # Post-processing
1322
+ video = self.decode_latents(latents)
1323
+
1324
+ # Convert to tensor
1325
+ if output_type == "latent":
1326
+ video = torch.from_numpy(video)
1327
+
1328
+ # Offload all models
1329
+ self.maybe_free_model_hooks()
1330
+
1331
+ if not return_dict:
1332
+ return video
1333
+
1334
+ return EasyAnimatePipelineOutput(videos=video)
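The `__call__` above is the public entry point of this inpaint pipeline. Below is a minimal usage sketch; the pipeline class name, checkpoint path, dtype and tensor shapes are illustrative assumptions rather than values taken from this commit, and the all-255 mask requests the pure text-to-video branch documented in the code.

import torch
from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_inpaint import (
    EasyAnimatePipeline_Multi_Text_Encoder_Inpaint,  # assumed class name for this module
)

pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained(
    "path/to/EasyAnimate-checkpoint",  # assumed local checkpoint path
    torch_dtype=torch.bfloat16,
).to("cuda")

video_length, height, width = 49, 512, 512
# A mask that is 255 everywhere triggers the text-to-video branch (zeroed inpaint latents).
input_video = torch.zeros([1, 3, video_length, height, width])
input_video_mask = torch.ones([1, 1, video_length, height, width]) * 255

output = pipeline(
    prompt="A panda playing a guitar in a bamboo forest",
    video_length=video_length,
    video=input_video,
    mask_video=input_video_mask,
    height=height,
    width=width,
    num_inference_steps=50,
    guidance_scale=7.0,
    strength=1.0,
)
video = output.videos  # EasyAnimatePipelineOutput when return_dict=True (output_type="latent" -> torch.Tensor)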
easyanimate/ui/ui.py CHANGED
The diff for this file is too large to render. See raw diff
 
easyanimate/utils/discrete_sampler.py ADDED
@@ -0,0 +1,46 @@
1
+ """Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
2
+ """
3
+ import torch
4
+
5
+ class DiscreteSampling:
6
+ def __init__(self, num_idx, uniform_sampling=False):
7
+ self.num_idx = num_idx
8
+ self.uniform_sampling = uniform_sampling
9
+ self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
10
+
11
+ if self.is_distributed and self.uniform_sampling:
12
+ world_size = torch.distributed.get_world_size()
13
+ self.rank = torch.distributed.get_rank()
14
+
15
+ i = 1
16
+ while True:
17
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
18
+ i += 1
19
+ else:
20
+ self.group_num = world_size // i
21
+ break
22
+ assert self.group_num > 0
23
+ assert world_size % self.group_num == 0
24
+ # the number of rank in one group
25
+ self.group_width = world_size // self.group_num
26
+ self.sigma_interval = self.num_idx // self.group_num
27
+ print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
28
+ self.rank, world_size, self.group_num,
29
+ self.group_width, self.sigma_interval))
30
+
31
+ def __call__(self, n_samples, generator=None, device=None):
32
+ if self.is_distributed and self.uniform_sampling:
33
+ group_index = self.rank // self.group_width
34
+ idx = torch.randint(
35
+ group_index * self.sigma_interval,
36
+ (group_index + 1) * self.sigma_interval,
37
+ (n_samples,),
38
+ generator=generator, device=device,
39
+ )
40
+ print('proc[%d] idx=%s' % (self.rank, idx))
41
+ else:
42
+ idx = torch.randint(
43
+ 0, self.num_idx, (n_samples,),
44
+ generator=generator, device=device,
45
+ )
46
+ return idx
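DiscreteSampling above draws the per-sample timestep indices used during training; with uniform_sampling enabled under torch.distributed it partitions the index range across rank groups. A minimal single-process sketch (values illustrative):

import torch
from easyanimate.utils.discrete_sampler import DiscreteSampling

# Non-distributed path: indices are drawn uniformly from [0, num_idx).
sampler = DiscreteSampling(num_idx=1000, uniform_sampling=False)
generator = torch.Generator().manual_seed(0)
timestep_indices = sampler(4, generator=generator)  # tensor of 4 indices in [0, 1000)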
easyanimate/utils/fp8_optimization.py ADDED
@@ -0,0 +1,28 @@
1
+ """Modified from https://github.com/kijai/ComfyUI-MochiWrapper
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
7
+ weight_dtype = cls.weight.dtype
8
+ cls.to(origin_dtype)
9
+
10
+ # Convert all inputs to the original dtype
11
+ inputs = [input.to(origin_dtype) for input in inputs]
12
+ out = cls.original_forward(*inputs, **kwargs)
13
+
14
+ cls.to(weight_dtype)
15
+ return out
16
+
17
+ def convert_weight_dtype_wrapper(module, origin_dtype):
18
+ for name, module in module.named_modules():
19
+ if name == "":
20
+ continue
21
+ original_forward = module.forward
22
+ if hasattr(module, "weight"):
23
+ setattr(module, "original_forward", original_forward)
24
+ setattr(
25
+ module,
26
+ "forward",
27
+ lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
28
+ )
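convert_weight_dtype_wrapper above lets weights sit in float8 storage while each wrapped submodule temporarily casts itself and its inputs back to the compute dtype for the forward pass. A small sketch, assuming a PyTorch build with torch.float8_e4m3fn support; the toy model is illustrative:

import torch
import torch.nn as nn
from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper

origin_dtype = torch.bfloat16
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8)).to(origin_dtype)
model.requires_grad_(False)

# Store the weights in float8, then wrap every weighted submodule so it computes in bfloat16.
model = model.to(torch.float8_e4m3fn)
convert_weight_dtype_wrapper(model, origin_dtype)

with torch.no_grad():
    out = model(torch.randn(1, 8, dtype=origin_dtype))  # forward runs in bfloat16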
easyanimate/utils/lora_utils.py CHANGED
@@ -156,8 +156,8 @@ def precalculate_safetensors_hashes(tensors, metadata):
156
 
157
 
158
  class LoRANetwork(torch.nn.Module):
159
- TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
160
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
161
  LORA_PREFIX_TRANSFORMER = "lora_unet"
162
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
  def __init__(
@@ -238,9 +238,10 @@ class LoRANetwork(torch.nn.Module):
238
  self.text_encoder_loras = []
239
  skipped_te = []
240
  for i, text_encoder in enumerate(text_encoders):
241
- text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
242
- self.text_encoder_loras.extend(text_encoder_loras)
243
- skipped_te += skipped
 
244
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
245
 
246
  self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
@@ -368,6 +369,7 @@ def create_network(
368
  def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
369
  LORA_PREFIX_TRANSFORMER = "lora_unet"
370
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
371
  if state_dict is None:
372
  state_dict = load_file(lora_path, device=device)
373
  else:
@@ -389,21 +391,24 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
389
  layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
390
  curr_layer = pipeline.transformer
391
 
392
- temp_name = layer_infos.pop(0)
393
- while len(layer_infos) > -1:
394
- try:
395
- curr_layer = curr_layer.__getattr__(temp_name)
396
- if len(layer_infos) > 0:
397
- temp_name = layer_infos.pop(0)
398
- elif len(layer_infos) == 0:
399
- break
400
- except Exception:
401
- if len(layer_infos) == 0:
402
- print('Error loading layer')
403
- if len(temp_name) > 0:
404
- temp_name += "_" + layer_infos.pop(0)
405
- else:
406
- temp_name = layer_infos.pop(0)
 
 
 
407
 
408
  weight_up = elems['lora_up.weight'].to(dtype)
409
  weight_down = elems['lora_down.weight'].to(dtype)
@@ -444,6 +449,7 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
444
  curr_layer = pipeline.transformer
445
 
446
  temp_name = layer_infos.pop(0)
 
447
  while len(layer_infos) > -1:
448
  try:
449
  curr_layer = curr_layer.__getattr__(temp_name)
 
156
 
157
 
158
  class LoRANetwork(torch.nn.Module):
159
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel", "HunyuanTransformer3DModel", "EasyAnimateTransformer3DModel"]
160
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
161
  LORA_PREFIX_TRANSFORMER = "lora_unet"
162
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
  def __init__(
 
238
  self.text_encoder_loras = []
239
  skipped_te = []
240
  for i, text_encoder in enumerate(text_encoders):
241
+ if text_encoder is not None:
242
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
243
+ self.text_encoder_loras.extend(text_encoder_loras)
244
+ skipped_te += skipped
245
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
246
 
247
  self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
 
369
  def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
  LORA_PREFIX_TRANSFORMER = "lora_unet"
371
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
372
+ SPECIAL_LAYER_NAME = ["text_proj_t5"]
373
  if state_dict is None:
374
  state_dict = load_file(lora_path, device=device)
375
  else:
 
391
  layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
392
  curr_layer = pipeline.transformer
393
 
394
+ try:
395
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
396
+ except Exception:
397
+ temp_name = layer_infos.pop(0)
398
+ while len(layer_infos) > -1:
399
+ try:
400
+ curr_layer = curr_layer.__getattr__(temp_name)
401
+ if len(layer_infos) > 0:
402
+ temp_name = layer_infos.pop(0)
403
+ elif len(layer_infos) == 0:
404
+ break
405
+ except Exception:
406
+ if len(layer_infos) == 0:
407
+ print('Error loading layer')
408
+ if len(temp_name) > 0:
409
+ temp_name += "_" + layer_infos.pop(0)
410
+ else:
411
+ temp_name = layer_infos.pop(0)
412
 
413
  weight_up = elems['lora_up.weight'].to(dtype)
414
  weight_down = elems['lora_down.weight'].to(dtype)
 
449
  curr_layer = pipeline.transformer
450
 
451
  temp_name = layer_infos.pop(0)
452
+ print(layer, curr_layer)
453
  while len(layer_infos) > -1:
454
  try:
455
  curr_layer = curr_layer.__getattr__(temp_name)
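The layer-name resolution in merge_lora above walks LoRA state-dict keys down to concrete transformer sub-modules. A small sketch of the key parsing; the key itself is hypothetical:

# Hypothetical LoRA key, shown only to illustrate the splitting performed in merge_lora.
LORA_PREFIX_TRANSFORMER = "lora_unet"
layer = "lora_unet_transformer_blocks_0_attn1_to_q"

layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
print(layer_infos)  # ['transformer', 'blocks', '0', 'attn1', 'to', 'q']
# merge_lora then greedily re-joins fragments ("transformer" + "_" + "blocks" -> "transformer_blocks",
# "to" + "_" + "q" -> "to_q") until each piece matches an attribute of pipeline.transformer.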
easyanimate/utils/utils.py CHANGED
@@ -1,13 +1,15 @@
 
1
  import os
2
 
 
3
  import imageio
4
  import numpy as np
5
  import torch
6
  import torchvision
7
- import cv2
8
  from einops import rearrange
9
  from PIL import Image
10
 
 
11
  def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
12
  target_pixels = int(base_resolution) * int(base_resolution)
13
  original_width, original_height = Image.open(image).size
@@ -73,13 +75,20 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
73
  def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
74
  if validation_image_start is not None and validation_image_end is not None:
75
  if type(validation_image_start) is str and os.path.isfile(validation_image_start):
76
- image_start = clip_image = Image.open(validation_image_start)
 
 
77
  else:
78
  image_start = clip_image = validation_image_start
 
 
 
79
  if type(validation_image_end) is str and os.path.isfile(validation_image_end):
80
- image_end = Image.open(validation_image_end)
 
81
  else:
82
  image_end = validation_image_end
 
83
 
84
  if type(image_start) is list:
85
  clip_image = clip_image[0]
@@ -119,8 +128,13 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
119
  elif validation_image_start is not None:
120
  if type(validation_image_start) is str and os.path.isfile(validation_image_start):
121
  image_start = clip_image = Image.open(validation_image_start).convert("RGB")
 
 
122
  else:
123
  image_start = clip_image = validation_image_start
 
 
 
124
 
125
  if type(image_start) is list:
126
  clip_image = clip_image[0]
@@ -142,30 +156,60 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
142
  input_video_mask = torch.zeros_like(input_video[:, :1])
143
  input_video_mask[:, :, 1:, ] = 255
144
  else:
 
 
145
  input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
146
  input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
147
  clip_image = None
148
 
 
 
 
 
149
  return input_video, input_video_mask, clip_image
150
 
151
- def video_frames(input_video_path):
152
- cap = cv2.VideoCapture(input_video_path)
153
- frames = []
154
- while True:
155
- ret, frame = cap.read()
156
- if not ret:
157
- break
158
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
159
- cap.release()
160
- cv2.destroyAllWindows()
161
- return frames
162
-
163
- def get_video_to_video_latent(validation_videos, video_length):
164
- input_video = video_frames(validation_videos)
 
 
 
 
 
 
 
 
 
 
 
165
  input_video = torch.from_numpy(np.array(input_video))[:video_length]
166
  input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
167
 
168
- input_video_mask = torch.zeros_like(input_video[:, :1])
169
- input_video_mask[:, :, :] = 255
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- return input_video, input_video_mask, None
 
1
+ import gc
2
  import os
3
 
4
+ import cv2
5
  import imageio
6
  import numpy as np
7
  import torch
8
  import torchvision
 
9
  from einops import rearrange
10
  from PIL import Image
11
 
12
+
13
  def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
14
  target_pixels = int(base_resolution) * int(base_resolution)
15
  original_width, original_height = Image.open(image).size
 
75
  def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
76
  if validation_image_start is not None and validation_image_end is not None:
77
  if type(validation_image_start) is str and os.path.isfile(validation_image_start):
78
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
79
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
80
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
81
  else:
82
  image_start = clip_image = validation_image_start
83
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
84
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
85
+
86
  if type(validation_image_end) is str and os.path.isfile(validation_image_end):
87
+ image_end = Image.open(validation_image_end).convert("RGB")
88
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
89
  else:
90
  image_end = validation_image_end
91
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
92
 
93
  if type(image_start) is list:
94
  clip_image = clip_image[0]
 
128
  elif validation_image_start is not None:
129
  if type(validation_image_start) is str and os.path.isfile(validation_image_start):
130
  image_start = clip_image = Image.open(validation_image_start).convert("RGB")
131
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
132
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
133
  else:
134
  image_start = clip_image = validation_image_start
135
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
136
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
137
+ image_end = None
138
 
139
  if type(image_start) is list:
140
  clip_image = clip_image[0]
 
156
  input_video_mask = torch.zeros_like(input_video[:, :1])
157
  input_video_mask[:, :, 1:, ] = 255
158
  else:
159
+ image_start = None
160
+ image_end = None
161
  input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
162
  input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
163
  clip_image = None
164
 
165
+ del image_start
166
+ del image_end
167
+ gc.collect()
168
+
169
  return input_video, input_video_mask, clip_image
170
 
171
+ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
172
+ if isinstance(input_video_path, str):
173
+ cap = cv2.VideoCapture(input_video_path)
174
+ input_video = []
175
+
176
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
177
+ frame_skip = 1 if fps is None else int(original_fps // fps)
178
+
179
+ frame_count = 0
180
+
181
+ while True:
182
+ ret, frame = cap.read()
183
+ if not ret:
184
+ break
185
+
186
+ if frame_count % frame_skip == 0:
187
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
188
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
189
+
190
+ frame_count += 1
191
+
192
+ cap.release()
193
+ else:
194
+ input_video = input_video_path
195
+
196
  input_video = torch.from_numpy(np.array(input_video))[:video_length]
197
  input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
198
 
199
+ if ref_image is not None:
200
+ ref_image = Image.open(ref_image)
201
+ ref_image = torch.from_numpy(np.array(ref_image))
202
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
203
+
204
+ if validation_video_mask is not None:
205
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
206
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
207
+
208
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
209
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
210
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
211
+ else:
212
+ input_video_mask = torch.zeros_like(input_video[:, :1])
213
+ input_video_mask[:, :, :] = 255
214
 
215
+ return input_video, input_video_mask, ref_image
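A minimal usage sketch for the reworked helper above, assuming it lives in easyanimate/utils/utils.py as in this commit; the file name, sample size and fps below are illustrative values, not part of the change:

from easyanimate.utils.utils import get_video_to_video_latent

sample_size = (512, 512)   # (height, width) the pipeline will run at
video_length = 49          # frames kept from the clip

input_video, input_video_mask, ref_image = get_video_to_video_latent(
    "input.mp4",                 # frames are read with OpenCV and resized to sample_size
    video_length=video_length,
    sample_size=sample_size,
    fps=8,                       # frames are skipped to approximate this rate
    validation_video_mask=None,  # optional grayscale mask image path
    ref_image=None,              # optional reference image path
)
# input_video: (1, 3, T, H, W) scaled to [0, 1]; with no mask supplied,
# input_video_mask is (1, 1, T, H, W) filled with 255.
print(input_video.shape, input_video_mask.shape)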
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml ADDED
@@ -0,0 +1,64 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.cogvideox_casual3dcnn.AutoencoderKLMagvit_CogVideoX
4
+ params:
5
+ latent_channels: 16
6
+ temporal_compression_ratio: 4
7
+ monitor: train/rec_loss
8
+ ckpt_path: vae/diffusion_pytorch_model.safetensors
9
+ down_block_types: ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
10
+ "CogVideoXDownBlock3D",)
11
+ up_block_types: ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
12
+ "CogVideoXUpBlock3D",)
13
+ lossconfig:
14
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
15
+ params:
16
+ disc_start: 50001
17
+ kl_weight: 1.0e-06
18
+ disc_weight: 0.5
19
+ l2_loss_weight: 0.1
20
+ l1_loss_weight: 1.0
21
+ perceptual_weight: 1.0
22
+
23
+ data:
24
+ target: train_vae.DataModuleFromConfig
25
+
26
+ params:
27
+ batch_size: 1
28
+ wrap: true
29
+ num_workers: 8
30
+ train:
31
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
32
+ params:
33
+ data_json_path: pretrain.json
34
+ data_root: /your_data_root # Root directory used to resolve relative paths
35
+ size: 256
36
+ degradation: pil_nearest
37
+ video_size: 256
38
+ video_len: 49
39
+ slice_interval: 1
40
+ validation:
41
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
42
+ params:
43
+ data_json_path: pretrain.json
44
+ data_root: /your_data_root # Root directory used to resolve relative paths
45
+ size: 256
46
+ degradation: pil_nearest
47
+ video_size: 256
48
+ video_len: 49
49
+ slice_interval: 1
50
+
51
+ lightning:
52
+ callbacks:
53
+ image_logger:
54
+ target: train_vae.ImageLogger
55
+ params:
56
+ batch_frequency: 5000
57
+ max_images: 8
58
+ increase_log_steps: True
59
+
60
+ trainer:
61
+ benchmark: True
62
+ accumulate_grad_batches: 1
63
+ gpus: "0"
64
+ num_nodes: 1
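A short sketch of how a config in this `target`/`params` style is typically instantiated; the OmegaConf plus instantiate_from_config pattern mirrors the imports used elsewhere in this repo, and the checkpoint location is only an assumption here:

from omegaconf import OmegaConf
from easyanimate.vae.ldm.util import instantiate_from_config

config = OmegaConf.load(
    "easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml"
)
# Builds AutoencoderKLMagvit_CogVideoX from model.target with model.params;
# note that params.ckpt_path is loaded at init, so it must point at a real file.
model = instantiate_from_config(config.model)
model.learning_rate = config.model.base_learning_rate  # read later by configure_optimizers
print(type(model).__name__)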
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml ADDED
@@ -0,0 +1,65 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
4
+ params:
5
+ spatial_group_norm: true
6
+ mid_block_attention_type: "spatial"
7
+ latent_channels: 16
8
+ monitor: train/rec_loss
9
+ ckpt_path: vae/diffusion_pytorch_model.safetensors
10
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
11
+ "SpatialTemporalDownBlock3D",)
12
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
13
+ "SpatialTemporalUpBlock3D",)
14
+ lossconfig:
15
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
16
+ params:
17
+ disc_start: 50001
18
+ kl_weight: 1.0e-06
19
+ disc_weight: 0.5
20
+ l2_loss_weight: 0.1
21
+ l1_loss_weight: 1.0
22
+ perceptual_weight: 1.0
23
+
24
+ data:
25
+ target: train_vae.DataModuleFromConfig
26
+
27
+ params:
28
+ batch_size: 1
29
+ wrap: true
30
+ num_workers: 8
31
+ train:
32
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
33
+ params:
34
+ data_json_path: pretrain.json
35
+ data_root: /your_data_root # Root directory used to resolve relative paths
36
+ size: 256
37
+ degradation: pil_nearest
38
+ video_size: 256
39
+ video_len: 49
40
+ slice_interval: 1
41
+ validation:
42
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
43
+ params:
44
+ data_json_path: pretrain.json
45
+ data_root: /your_data_root # Root directory used to resolve relative paths
46
+ size: 256
47
+ degradation: pil_nearest
48
+ video_size: 256
49
+ video_len: 49
50
+ slice_interval: 1
51
+
52
+ lightning:
53
+ callbacks:
54
+ image_logger:
55
+ target: train_vae.ImageLogger
56
+ params:
57
+ batch_frequency: 5000
58
+ max_images: 8
59
+ increase_log_steps: True
60
+
61
+ trainer:
62
+ benchmark: True
63
+ accumulate_grad_batches: 1
64
+ gpus: "0"
65
+ num_nodes: 1
easyanimate/vae/ldm/data/dataset_callback.py CHANGED
@@ -1,6 +1,7 @@
1
  #-*- encoding:utf-8 -*-
2
  from pytorch_lightning.callbacks import Callback
3
 
 
4
  class DatasetCallback(Callback):
5
  def __init__(self):
6
  self.sampler_pos_start = 0
 
1
  #-*- encoding:utf-8 -*-
2
  from pytorch_lightning.callbacks import Callback
3
 
4
+
5
  class DatasetCallback(Callback):
6
  def __init__(self):
7
  self.sampler_pos_start = 0
easyanimate/vae/ldm/data/dataset_image_video.py CHANGED
@@ -17,7 +17,7 @@ from decord import VideoReader
17
  from func_timeout import FunctionTimedOut, func_set_timeout
18
  from omegaconf import OmegaConf
19
  from PIL import Image
20
- from torch.utils.data import (BatchSampler, Dataset, Sampler)
21
  from tqdm import tqdm
22
 
23
  from ..modules.image_degradation import (degradation_fn_bsr,
@@ -164,15 +164,18 @@ class ImageVideoDataset(Dataset):
164
  return self.base[index].get('type', 'image')
165
 
166
  def __getitem__(self, i):
167
- @func_set_timeout(3) # time wait 3 seconds
168
  def get_video_item(example):
169
  if self.data_root is not None:
170
  video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
171
  else:
172
  video_reader = VideoReader(example['file_path'])
173
  video_length = len(video_reader)
174
-
175
- clip_length = min(video_length, (self.video_len - 1) * self.slice_interval + 1)
176
  start_idx = random.randint(0, video_length - clip_length)
177
  batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
178
 
 
17
  from func_timeout import FunctionTimedOut, func_set_timeout
18
  from omegaconf import OmegaConf
19
  from PIL import Image
20
+ from torch.utils.data import BatchSampler, Dataset, Sampler
21
  from tqdm import tqdm
22
 
23
  from ..modules.image_degradation import (degradation_fn_bsr,
 
164
  return self.base[index].get('type', 'image')
165
 
166
  def __getitem__(self, i):
167
+ @func_set_timeout(15) # wait at most 15 seconds
168
  def get_video_item(example):
169
  if self.data_root is not None:
170
  video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
171
  else:
172
  video_reader = VideoReader(example['file_path'])
173
  video_length = len(video_reader)
174
+ if self.slice_interval == "rand":
175
+ slice_interval = np.random.choice([1, 2, 3])
176
+ else:
177
+ slice_interval = int(self.slice_interval)
178
+ clip_length = min(video_length, (self.video_len - 1) * slice_interval + 1)
179
  start_idx = random.randint(0, video_length - clip_length)
180
  batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
181
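To make the new sampling path above concrete: when slice_interval is "rand", the frame stride is drawn per clip and the indices are spread evenly over the resulting window. A standalone sketch with arbitrary example numbers:

import random
import numpy as np

video_length = 200                              # frames available in the clip (example value)
video_len = 49                                  # frames the dataset returns
slice_interval = np.random.choice([1, 2, 3])    # the "rand" branch added above

clip_length = min(video_length, (video_len - 1) * slice_interval + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, video_len, dtype=int)
print(slice_interval, clip_length, batch_index[:5])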
 
easyanimate/vae/ldm/models/casual3dcnn.py ADDED
@@ -0,0 +1,337 @@
1
+ import time
2
+ from contextlib import contextmanager
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from ..modules.diffusionmodules.model import Decoder, Encoder
9
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
10
+ from ..util import instantiate_from_config
11
+ from .enc_dec import Decoder as Mag_Decoder
12
+ from .enc_dec import Encoder as Mag_Encoder
13
+
14
+
15
+ class AutoencoderKLMagvit(pl.LightningModule):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ ):
26
+ super().__init__()
27
+ self.image_key = image_key
28
+ self.encoder = Mag_Encoder()
29
+ self.decoder = Mag_Decoder()
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ self.quant_conv = torch.nn.Conv3d(16, 16, 1)
32
+ self.post_quant_conv = torch.nn.Conv3d(8, 8, 1)
33
+ self.embed_dim = embed_dim
34
+ if colorize_nlabels is not None:
35
+ assert type(colorize_nlabels)==int
36
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
37
+ if monitor is not None:
38
+ self.monitor = monitor
39
+ if ckpt_path is not None:
40
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
41
+
42
+ def init_from_ckpt(self, path, ignore_keys=list()):
43
+ sd = torch.load(path, map_location="cpu")["state_dict"]
44
+ keys = list(sd.keys())
45
+ for k in keys:
46
+ for ik in ignore_keys:
47
+ if k.startswith(ik):
48
+ print("Deleting key {} from state_dict.".format(k))
49
+ del sd[k]
50
+ self.load_state_dict(sd, strict=False)
51
+ print(f"Restored from {path}")
52
+
53
+ def encode(self, x):
54
+ h = self.encoder(x)
55
+ moments = self.quant_conv(h)
56
+ posterior = DiagonalGaussianDistribution(moments)
57
+ return posterior
58
+
59
+ def decode(self, z):
60
+ z = self.post_quant_conv(z)
61
+ dec = self.decoder(z)
62
+ return dec
63
+
64
+ def forward(self, input, sample_posterior=True):
65
+ if input.ndim==4:
66
+ input = input.unsqueeze(2)
67
+ posterior = self.encode(input)
68
+ if sample_posterior:
69
+ z = posterior.sample()
70
+ else:
71
+ z = posterior.mode()
72
+ dec = self.decode(z)
73
+ return dec, posterior
74
+
75
+ def get_input(self, batch, k):
76
+ x = batch[k]
77
+ if x.ndim==5:
78
+ x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
79
+ return x
80
+ if len(x.shape) == 3:
81
+ x = x[..., None]
82
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
83
+ return x
84
+
85
+ def training_step(self, batch, batch_idx, optimizer_idx):
86
+ # tic = time.time()
87
+ inputs = self.get_input(batch, self.image_key)
88
+ # print(f"get_input time {time.time() - tic}")
89
+ # tic = time.time()
90
+ reconstructions, posterior = self(inputs)
91
+ # print(f"model forward time {time.time() - tic}")
92
+
93
+ if optimizer_idx == 0:
94
+ # train encoder+decoder+logvar
95
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
96
+ last_layer=self.get_last_layer(), split="train")
97
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
98
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
99
+ # print(f"cal loss time {time.time() - tic}")
100
+ return aeloss
101
+
102
+ if optimizer_idx == 1:
103
+ # train the discriminator
104
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
105
+ last_layer=self.get_last_layer(), split="train")
106
+
107
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
108
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
109
+ # print(f"cal loss time {time.time() - tic}")
110
+ return discloss
111
+
112
+ def validation_step(self, batch, batch_idx):
113
+ with torch.no_grad():
114
+ inputs = self.get_input(batch, self.image_key)
115
+ reconstructions, posterior = self(inputs)
116
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
117
+ last_layer=self.get_last_layer(), split="val")
118
+
119
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
120
+ last_layer=self.get_last_layer(), split="val")
121
+
122
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
123
+ self.log_dict(log_dict_ae)
124
+ self.log_dict(log_dict_disc)
125
+ return self.log_dict
126
+
127
+ def configure_optimizers(self):
128
+ lr = self.learning_rate
129
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
130
+ list(self.decoder.parameters())+
131
+ list(self.quant_conv.parameters())+
132
+ list(self.post_quant_conv.parameters()),
133
+ lr=lr, betas=(0.5, 0.9))
134
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
135
+ lr=lr, betas=(0.5, 0.9))
136
+ return [opt_ae, opt_disc], []
137
+
138
+ def get_last_layer(self):
139
+ return self.decoder.conv_out.weight
140
+
141
+ @torch.no_grad()
142
+ def log_images(self, batch, only_inputs=False, **kwargs):
143
+ log = dict()
144
+ x = self.get_input(batch, self.image_key)
145
+ x = x.to(self.device)
146
+ if not only_inputs:
147
+ xrec, posterior = self(x)
148
+ if x.shape[1] > 3:
149
+ # colorize with random projection
150
+ assert xrec.shape[1] > 3
151
+ x = self.to_rgb(x)
152
+ xrec = self.to_rgb(xrec)
153
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
154
+ log["reconstructions"] = xrec
155
+ log["inputs"] = x
156
+ return log
157
+
158
+ def to_rgb(self, x):
159
+ assert self.image_key == "segmentation"
160
+ if not hasattr(self, "colorize"):
161
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
162
+ x = F.conv2d(x, weight=self.colorize)
163
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
164
+ return x
165
+
166
+ class AutoencoderKL(pl.LightningModule):
167
+ def __init__(self,
168
+ ddconfig,
169
+ lossconfig,
170
+ embed_dim,
171
+ ckpt_path=None,
172
+ ignore_keys=[],
173
+ image_key="image",
174
+ colorize_nlabels=None,
175
+ monitor=None,
176
+ ):
177
+ super().__init__()
178
+ self.image_key = image_key
179
+ self.encoder = Encoder(**ddconfig)
180
+ self.decoder = Decoder(**ddconfig)
181
+ self.loss = instantiate_from_config(lossconfig)
182
+ assert ddconfig["double_z"]
183
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
184
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
185
+ self.embed_dim = embed_dim
186
+ if colorize_nlabels is not None:
187
+ assert type(colorize_nlabels)==int
188
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
189
+ if monitor is not None:
190
+ self.monitor = monitor
191
+ if ckpt_path is not None:
192
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
193
+
194
+ def init_from_ckpt(self, path, ignore_keys=list()):
195
+ sd = torch.load(path, map_location="cpu")["state_dict"]
196
+ keys = list(sd.keys())
197
+ for k in keys:
198
+ for ik in ignore_keys:
199
+ if k.startswith(ik):
200
+ print("Deleting key {} from state_dict.".format(k))
201
+ del sd[k]
202
+ self.load_state_dict(sd, strict=False)
203
+ print(f"Restored from {path}")
204
+
205
+ def encode(self, x):
206
+ h = self.encoder(x)
207
+ moments = self.quant_conv(h)
208
+ posterior = DiagonalGaussianDistribution(moments)
209
+ return posterior
210
+
211
+ def decode(self, z):
212
+ z = self.post_quant_conv(z)
213
+ dec = self.decoder(z)
214
+ return dec
215
+
216
+ def forward(self, input, sample_posterior=True):
217
+ posterior = self.encode(input)
218
+ if sample_posterior:
219
+ z = posterior.sample()
220
+ else:
221
+ z = posterior.mode()
222
+ dec = self.decode(z)
223
+ return dec, posterior
224
+
225
+ def get_input(self, batch, k):
226
+ x = batch[k]
227
+ if len(x.shape) == 3:
228
+ x = x[..., None]
229
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
230
+ return x
231
+
232
+ def training_step(self, batch, batch_idx, optimizer_idx):
233
+ # tic = time.time()
234
+ inputs = self.get_input(batch, self.image_key)
235
+ # print(f"get_input time {time.time() - tic}")
236
+ # tic = time.time()
237
+ reconstructions, posterior = self(inputs)
238
+ # print(f"model forward time {time.time() - tic}")
239
+ tic = time.time()
240
+
241
+ if optimizer_idx == 0:
242
+ # train encoder+decoder+logvar
243
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
244
+ last_layer=self.get_last_layer(), split="train")
245
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
246
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
247
+ # print(f"cal loss time {time.time() - tic}")
248
+ return aeloss
249
+
250
+ if optimizer_idx == 1:
251
+ # train the discriminator
252
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
253
+ last_layer=self.get_last_layer(), split="train")
254
+
255
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
256
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
257
+ # print(f"cal loss time {time.time() - tic}")
258
+ return discloss
259
+
260
+ def validation_step(self, batch, batch_idx):
261
+ tic = time.time()
262
+ inputs = self.get_input(batch, self.image_key)
263
+ print(f"get_input time {time.time() - tic}")
264
+ tic = time.time()
265
+ reconstructions, posterior = self(inputs)
266
+ print(f"val forward time {time.time() - tic}")
267
+ tic = time.time()
268
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
269
+ last_layer=self.get_last_layer(), split="val")
270
+
271
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
272
+ last_layer=self.get_last_layer(), split="val")
273
+
274
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
275
+ self.log_dict(log_dict_ae)
276
+ self.log_dict(log_dict_disc)
277
+ print(f"val end time {time.time() - tic}")
278
+ return self.log_dict
279
+
280
+ def configure_optimizers(self):
281
+ lr = self.learning_rate
282
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
283
+ list(self.decoder.parameters())+
284
+ list(self.quant_conv.parameters())+
285
+ list(self.post_quant_conv.parameters()),
286
+ lr=lr, betas=(0.5, 0.9))
287
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
288
+ lr=lr, betas=(0.5, 0.9))
289
+ return [opt_ae, opt_disc], []
290
+
291
+ def get_last_layer(self):
292
+ return self.decoder.conv_out.weight
293
+
294
+ @torch.no_grad()
295
+ def log_images(self, batch, only_inputs=False, **kwargs):
296
+ log = dict()
297
+ x = self.get_input(batch, self.image_key)
298
+ x = x.to(self.device)
299
+ if not only_inputs:
300
+ xrec, posterior = self(x)
301
+ if x.shape[1] > 3:
302
+ # colorize with random projection
303
+ assert xrec.shape[1] > 3
304
+ x = self.to_rgb(x)
305
+ xrec = self.to_rgb(xrec)
306
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
307
+ log["reconstructions"] = xrec
308
+ log["inputs"] = x
309
+ return log
310
+
311
+ def to_rgb(self, x):
312
+ assert self.image_key == "segmentation"
313
+ if not hasattr(self, "colorize"):
314
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
315
+ x = F.conv2d(x, weight=self.colorize)
316
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
317
+ return x
318
+
319
+
320
+ class IdentityFirstStage(torch.nn.Module):
321
+ def __init__(self, *args, vq_interface=False, **kwargs):
322
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
323
+ super().__init__()
324
+
325
+ def encode(self, x, *args, **kwargs):
326
+ return x
327
+
328
+ def decode(self, x, *args, **kwargs):
329
+ return x
330
+
331
+ def quantize(self, x, *args, **kwargs):
332
+ if self.vq_interface:
333
+ return x, None, [None, None, None]
334
+ return x
335
+
336
+ def forward(self, x, *args, **kwargs):
337
+ return x
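The training_step above follows the usual VAE-GAN recipe: optimizer index 0 updates the autoencoder (reconstruction, KL and perceptual terms), index 1 updates the discriminator. A condensed sketch of that alternation outside Lightning, purely illustrative and not part of this commit:

def manual_train_step(model, batch, opt_ae, opt_disc):
    inputs = model.get_input(batch, model.image_key)

    # optimizer_idx == 0: autoencoder pass
    recon, posterior = model(inputs)
    aeloss, _ = model.loss(inputs, recon, posterior, 0, model.global_step,
                           last_layer=model.get_last_layer(), split="train")
    opt_ae.zero_grad(); aeloss.backward(); opt_ae.step()

    # optimizer_idx == 1: discriminator pass on a fresh graph
    recon, posterior = model(inputs)
    discloss, _ = model.loss(inputs, recon, posterior, 1, model.global_step,
                             last_layer=model.get_last_layer(), split="train")
    opt_disc.zero_grad(); discloss.backward(); opt_disc.step()
    return aeloss.item(), discloss.item()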
easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py ADDED
@@ -0,0 +1,326 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from ..util import instantiate_from_config
12
+ from .cogvideox_enc_dec import (CogVideoXDecoder3D, CogVideoXEncoder3D,
13
+ CogVideoXSafeConv3d)
14
+
15
+
16
+ class DiagonalGaussianDistribution:
17
+ def __init__(
18
+ self,
19
+ mean: torch.Tensor,
20
+ logvar: torch.Tensor,
21
+ deterministic: bool = False,
22
+ ):
23
+ self.mean = mean
24
+ self.logvar = torch.clamp(logvar, -30.0, 20.0)
25
+ self.deterministic = deterministic
26
+
27
+ if deterministic:
28
+ self.var = self.std = torch.zeros_like(self.mean)
29
+ else:
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+
33
+ def sample(self, generator = None) -> torch.FloatTensor:
34
+ x = torch.randn(
35
+ self.mean.shape,
36
+ generator=generator,
37
+ device=self.mean.device,
38
+ dtype=self.mean.dtype,
39
+ )
40
+ return self.mean + self.std * x
41
+
42
+ def mode(self):
43
+ return self.mean
44
+
45
+ def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
46
+ dims = list(range(1, self.mean.ndim))
47
+
48
+ if self.deterministic:
49
+ return torch.Tensor([0.0])
50
+ else:
51
+ if other is None:
52
+ return 0.5 * torch.sum(
53
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
54
+ dim=dims,
55
+ )
56
+ else:
57
+ return 0.5 * torch.sum(
58
+ torch.pow(self.mean - other.mean, 2) / other.var
59
+ + self.var / other.var
60
+ - 1.0
61
+ - self.logvar
62
+ + other.logvar,
63
+ dim=dims,
64
+ )
65
+
66
+ def nll(self, sample: torch.Tensor) -> torch.Tensor:
67
+ dims = list(range(1, self.mean.ndim))
68
+
69
+ if self.deterministic:
70
+ return torch.Tensor([0.0])
71
+
72
+ logtwopi = np.log(2.0 * np.pi)
73
+ return 0.5 * torch.sum(
74
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
75
+ dim=dims,
76
+ )
77
+
78
+ @dataclass
79
+ class EncoderOutput:
80
+ latent_dist: DiagonalGaussianDistribution
81
+
82
+ @dataclass
83
+ class DecoderOutput:
84
+ sample: torch.Tensor
85
+
86
+ def str_eval(item):
87
+ if type(item) == str:
88
+ return eval(item)
89
+ else:
90
+ return item
91
+
92
+ class AutoencoderKLMagvit_CogVideoX(pl.LightningModule):
93
+ def __init__(
94
+ self,
95
+ in_channels: int = 3,
96
+ out_channels: int = 3,
97
+ down_block_types: Tuple[str] = (
98
+ "CogVideoXDownBlock3D",
99
+ "CogVideoXDownBlock3D",
100
+ "CogVideoXDownBlock3D",
101
+ "CogVideoXDownBlock3D",
102
+ ),
103
+ up_block_types: Tuple[str] = (
104
+ "CogVideoXUpBlock3D",
105
+ "CogVideoXUpBlock3D",
106
+ "CogVideoXUpBlock3D",
107
+ "CogVideoXUpBlock3D",
108
+ ),
109
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
110
+ latent_channels: int = 16,
111
+ layers_per_block: int = 3,
112
+ act_fn: str = "silu",
113
+ norm_eps: float = 1e-6,
114
+ norm_num_groups: int = 32,
115
+ temporal_compression_ratio: float = 4,
116
+ use_quant_conv: bool = False,
117
+ use_post_quant_conv: bool = False,
118
+
119
+ mini_batch_encoder=4,
120
+ mini_batch_decoder=1,
121
+
122
+ image_key="image",
123
+ train_decoder_only=False,
124
+ train_encoder_only=False,
125
+ monitor=None,
126
+ ckpt_path=None,
127
+ lossconfig=None,
128
+ ):
129
+ super().__init__()
130
+ self.image_key = image_key
131
+ down_block_types = str_eval(down_block_types)
132
+ up_block_types = str_eval(up_block_types)
133
+
134
+ self.encoder = CogVideoXEncoder3D(
135
+ in_channels=in_channels,
136
+ out_channels=latent_channels,
137
+ down_block_types=down_block_types,
138
+ block_out_channels=block_out_channels,
139
+ layers_per_block=layers_per_block,
140
+ act_fn=act_fn,
141
+ norm_eps=norm_eps,
142
+ norm_num_groups=norm_num_groups,
143
+ temporal_compression_ratio=temporal_compression_ratio,
144
+ )
145
+
146
+ self.decoder = CogVideoXDecoder3D(
147
+ in_channels=latent_channels,
148
+ out_channels=out_channels,
149
+ up_block_types=up_block_types,
150
+ block_out_channels=block_out_channels,
151
+ layers_per_block=layers_per_block,
152
+ act_fn=act_fn,
153
+ norm_eps=norm_eps,
154
+ norm_num_groups=norm_num_groups,
155
+ temporal_compression_ratio=temporal_compression_ratio,
156
+ )
157
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
158
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
159
+
160
+ self.mini_batch_encoder = mini_batch_encoder
161
+ self.mini_batch_decoder = mini_batch_decoder
162
+ self.train_decoder_only = train_decoder_only
163
+ self.train_encoder_only = train_encoder_only
164
+ if train_decoder_only:
165
+ self.encoder.requires_grad_(False)
166
+ if self.quant_conv is not None:
167
+ self.quant_conv.requires_grad_(False)
168
+ if train_encoder_only:
169
+ self.decoder.requires_grad_(False)
170
+ if self.post_quant_conv is not None:
171
+ self.post_quant_conv.requires_grad_(False)
172
+ if monitor is not None:
173
+ self.monitor = monitor
174
+ if ckpt_path is not None:
175
+ self.init_from_ckpt(ckpt_path, ignore_keys=["loss"])
176
+ if lossconfig is not None:
177
+ self.loss = instantiate_from_config(lossconfig)
178
+
179
+ def init_from_ckpt(self, path, ignore_keys=list()):
180
+ if path.endswith("safetensors"):
181
+ from safetensors.torch import load_file, safe_open
182
+ sd = load_file(path)
183
+ else:
184
+ sd = torch.load(path, map_location="cpu")
185
+ if "state_dict" in list(sd.keys()):
186
+ sd = sd["state_dict"]
187
+ keys = list(sd.keys())
188
+ for k in keys:
189
+ for ik in ignore_keys:
190
+ if k.startswith(ik):
191
+ print("Deleting key {} from state_dict.".format(k))
192
+ del sd[k]
193
+ m, u = self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
194
+ print(f"Restored from {path}")
195
+ print(f"missing keys: {str(m)}, unexpected keys: {str(u)}")
196
+
197
+ def encode(self, x: torch.Tensor) -> EncoderOutput:
198
+ h = self.encoder(x)
199
+ self.encoder._clear_fake_context_parallel_cache()
200
+
201
+ if self.quant_conv is not None:
202
+ moments: torch.Tensor = self.quant_conv(h)
203
+ else:
204
+ moments: torch.Tensor = h
205
+ mean, logvar = moments.chunk(2, dim=1)
206
+ posterior = DiagonalGaussianDistribution(mean, logvar)
207
+
208
+ return posterior
209
+
210
+ def decode(self, z: torch.Tensor) -> DecoderOutput:
211
+ if self.post_quant_conv is not None:
212
+ z = self.post_quant_conv(z)
213
+ decoded = self.decoder(z)
214
+ self.decoder._clear_fake_context_parallel_cache()
215
+ return decoded
216
+
217
+ def forward(self, input, sample_posterior=True):
218
+ if input.ndim==4:
219
+ input = input.unsqueeze(2)
220
+ posterior = self.encode(input)
221
+ if sample_posterior:
222
+ z = posterior.sample()
223
+ else:
224
+ z = posterior.mode()
225
+ # print("stt latent shape", z.shape)
226
+ dec = self.decode(z)
227
+ return dec, posterior
228
+
229
+ def get_input(self, batch, k):
230
+ x = batch[k]
231
+ if x.ndim==5:
232
+ x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
233
+ return x
234
+ if len(x.shape) == 3:
235
+ x = x[..., None]
236
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
237
+ return x
238
+
239
+ def training_step(self, batch, batch_idx, optimizer_idx):
240
+ inputs = self.get_input(batch, self.image_key)
241
+ reconstructions, posterior = self(inputs)
242
+
243
+ if optimizer_idx == 0:
244
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
245
+ last_layer=self.get_last_layer(), split="train")
246
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
247
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
248
+ return aeloss
249
+
250
+ if optimizer_idx == 1:
251
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
252
+ last_layer=self.get_last_layer(), split="train")
253
+
254
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
255
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
256
+ return discloss
257
+
258
+ def validation_step(self, batch, batch_idx):
259
+ with torch.no_grad():
260
+ inputs = self.get_input(batch, self.image_key)
261
+ reconstructions, posterior = self(inputs)
262
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
263
+ last_layer=self.get_last_layer(), split="val")
264
+
265
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
266
+ last_layer=self.get_last_layer(), split="val")
267
+
268
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
269
+ self.log_dict(log_dict_ae)
270
+ self.log_dict(log_dict_disc)
271
+ return self.log_dict
272
+
273
+ def configure_optimizers(self):
274
+ lr = self.learning_rate
275
+ if self.train_decoder_only:
276
+ if self.post_quant_conv is not None:
277
+ training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
278
+ else:
279
+ training_list = list(self.decoder.parameters())
280
+ opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
281
+ elif self.train_encoder_only:
282
+ if self.quant_conv is not None:
283
+ training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
284
+ else:
285
+ training_list = list(self.encoder.parameters())
286
+ opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
287
+ else:
288
+ training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
289
+ if self.quant_conv is not None:
290
+ training_list = training_list + list(self.quant_conv.parameters())
291
+ if self.post_quant_conv is not None:
292
+ training_list = training_list + list(self.post_quant_conv.parameters())
293
+ opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
294
+ opt_disc = torch.optim.Adam(
295
+ list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
296
+ lr=lr, betas=(0.5, 0.9)
297
+ )
298
+ return [opt_ae, opt_disc], []
299
+
300
+ def get_last_layer(self):
301
+ return self.decoder.conv_out.conv.weight
302
+
303
+ @torch.no_grad()
304
+ def log_images(self, batch, only_inputs=False, **kwargs):
305
+ log = dict()
306
+ x = self.get_input(batch, self.image_key)
307
+ x = x.to(self.device)
308
+ if not only_inputs:
309
+ xrec, posterior = self(x)
310
+ if x.shape[1] > 3:
311
+ # colorize with random projection
312
+ assert xrec.shape[1] > 3
313
+ x = self.to_rgb(x)
314
+ xrec = self.to_rgb(xrec)
315
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
316
+ log["reconstructions"] = xrec
317
+ log["inputs"] = x
318
+ return log
319
+
320
+ def to_rgb(self, x):
321
+ assert self.image_key == "segmentation"
322
+ if not hasattr(self, "colorize"):
323
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
324
+ x = F.conv2d(x, weight=self.colorize)
325
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
326
+ return x
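A rough round-trip shape check for the wrapper above, run with untrained weights (no ckpt_path or lossconfig); the small input size is only to keep the sketch cheap, and the exact latent frame count depends on the causal temporal handling:

import torch
from easyanimate.vae.ldm.models.cogvideox_casual3dcnn import AutoencoderKLMagvit_CogVideoX

vae = AutoencoderKLMagvit_CogVideoX(latent_channels=16, temporal_compression_ratio=4).eval()
video = torch.randn(1, 3, 9, 64, 64)    # (B, C, T, H, W) with T = 4k + 1

with torch.no_grad():
    posterior = vae.encode(video)        # DiagonalGaussianDistribution
    z = posterior.mode()                 # roughly (1, 16, T', H/8, W/8)
    recon = vae.decode(z)
print(z.shape, recon.shape)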
easyanimate/vae/ldm/models/cogvideox_enc_dec.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from diffusers.models.autoencoders.autoencoder_kl_cogvideox import (
21
+ CogVideoXCausalConv3d, CogVideoXDownBlock3D, CogVideoXMidBlock3D,
22
+ CogVideoXSafeConv3d, CogVideoXSpatialNorm3D, CogVideoXUpBlock3D)
23
+ from diffusers.utils import logging
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ class CogVideoXEncoder3D(nn.Module):
29
+ r"""
30
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
31
+
32
+ Args:
33
+ in_channels (`int`, *optional*, defaults to 3):
34
+ The number of input channels.
35
+ out_channels (`int`, *optional*, defaults to 3):
36
+ The number of output channels.
37
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
38
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
39
+ options.
40
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
41
+ The number of output channels for each block.
42
+ act_fn (`str`, *optional*, defaults to `"silu"`):
43
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
44
+ layers_per_block (`int`, *optional*, defaults to 2):
45
+ The number of layers per block.
46
+ norm_num_groups (`int`, *optional*, defaults to 32):
47
+ The number of groups for normalization.
48
+ """
49
+
50
+ _supports_gradient_checkpointing = True
51
+
52
+ def __init__(
53
+ self,
54
+ in_channels: int = 3,
55
+ out_channels: int = 16,
56
+ down_block_types: Tuple[str, ...] = (
57
+ "CogVideoXDownBlock3D",
58
+ "CogVideoXDownBlock3D",
59
+ "CogVideoXDownBlock3D",
60
+ "CogVideoXDownBlock3D",
61
+ ),
62
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
63
+ layers_per_block: int = 3,
64
+ act_fn: str = "silu",
65
+ norm_eps: float = 1e-6,
66
+ norm_num_groups: int = 32,
67
+ dropout: float = 0.0,
68
+ pad_mode: str = "first",
69
+ temporal_compression_ratio: float = 4,
70
+ ):
71
+ super().__init__()
72
+
73
+ # log2 of temporal_compress_times
74
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
75
+
76
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
77
+ self.down_blocks = nn.ModuleList([])
78
+
79
+ # down blocks
80
+ output_channel = block_out_channels[0]
81
+ for i, down_block_type in enumerate(down_block_types):
82
+ input_channel = output_channel
83
+ output_channel = block_out_channels[i]
84
+ is_final_block = i == len(block_out_channels) - 1
85
+ compress_time = i < temporal_compress_level
86
+
87
+ if down_block_type == "CogVideoXDownBlock3D":
88
+ down_block = CogVideoXDownBlock3D(
89
+ in_channels=input_channel,
90
+ out_channels=output_channel,
91
+ temb_channels=0,
92
+ dropout=dropout,
93
+ num_layers=layers_per_block,
94
+ resnet_eps=norm_eps,
95
+ resnet_act_fn=act_fn,
96
+ resnet_groups=norm_num_groups,
97
+ add_downsample=not is_final_block,
98
+ compress_time=compress_time,
99
+ )
100
+ else:
101
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
102
+
103
+ self.down_blocks.append(down_block)
104
+
105
+ # mid block
106
+ self.mid_block = CogVideoXMidBlock3D(
107
+ in_channels=block_out_channels[-1],
108
+ temb_channels=0,
109
+ dropout=dropout,
110
+ num_layers=2,
111
+ resnet_eps=norm_eps,
112
+ resnet_act_fn=act_fn,
113
+ resnet_groups=norm_num_groups,
114
+ pad_mode=pad_mode,
115
+ )
116
+
117
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
118
+ self.conv_act = nn.SiLU()
119
+ self.conv_out = CogVideoXCausalConv3d(
120
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
121
+ )
122
+
123
+ self.gradient_checkpointing = False
124
+
125
+ def _clear_fake_context_parallel_cache(self):
126
+ for name, module in self.named_modules():
127
+ if isinstance(module, CogVideoXCausalConv3d):
128
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
129
+ module._clear_fake_context_parallel_cache()
130
+
131
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
132
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
133
+ hidden_states = self.conv_in(sample)
134
+
135
+ if self.training and self.gradient_checkpointing:
136
+
137
+ def create_custom_forward(module):
138
+ def custom_forward(*inputs):
139
+ return module(*inputs)
140
+
141
+ return custom_forward
142
+
143
+ # 1. Down
144
+ for down_block in self.down_blocks:
145
+ hidden_states = torch.utils.checkpoint.checkpoint(
146
+ create_custom_forward(down_block), hidden_states, temb, None
147
+ )
148
+
149
+ # 2. Mid
150
+ hidden_states = torch.utils.checkpoint.checkpoint(
151
+ create_custom_forward(self.mid_block), hidden_states, temb, None
152
+ )
153
+ else:
154
+ # 1. Down
155
+ for down_block in self.down_blocks:
156
+ hidden_states = down_block(hidden_states, temb, None)
157
+
158
+ # 2. Mid
159
+ hidden_states = self.mid_block(hidden_states, temb, None)
160
+
161
+ # 3. Post-process
162
+ hidden_states = self.norm_out(hidden_states)
163
+ hidden_states = self.conv_act(hidden_states)
164
+ hidden_states = self.conv_out(hidden_states)
165
+ return hidden_states
166
+
167
+
168
+ class CogVideoXDecoder3D(nn.Module):
169
+ r"""
170
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
171
+ sample.
172
+
173
+ Args:
174
+ in_channels (`int`, *optional*, defaults to 3):
175
+ The number of input channels.
176
+ out_channels (`int`, *optional*, defaults to 3):
177
+ The number of output channels.
178
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
179
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
180
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
181
+ The number of output channels for each block.
182
+ act_fn (`str`, *optional*, defaults to `"silu"`):
183
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
184
+ layers_per_block (`int`, *optional*, defaults to 2):
185
+ The number of layers per block.
186
+ norm_num_groups (`int`, *optional*, defaults to 32):
187
+ The number of groups for normalization.
188
+ """
189
+
190
+ _supports_gradient_checkpointing = True
191
+
192
+ def __init__(
193
+ self,
194
+ in_channels: int = 16,
195
+ out_channels: int = 3,
196
+ up_block_types: Tuple[str, ...] = (
197
+ "CogVideoXUpBlock3D",
198
+ "CogVideoXUpBlock3D",
199
+ "CogVideoXUpBlock3D",
200
+ "CogVideoXUpBlock3D",
201
+ ),
202
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
203
+ layers_per_block: int = 3,
204
+ act_fn: str = "silu",
205
+ norm_eps: float = 1e-6,
206
+ norm_num_groups: int = 32,
207
+ dropout: float = 0.0,
208
+ pad_mode: str = "first",
209
+ temporal_compression_ratio: float = 4,
210
+ ):
211
+ super().__init__()
212
+
213
+ reversed_block_out_channels = list(reversed(block_out_channels))
214
+
215
+ self.conv_in = CogVideoXCausalConv3d(
216
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
217
+ )
218
+
219
+ # mid block
220
+ self.mid_block = CogVideoXMidBlock3D(
221
+ in_channels=reversed_block_out_channels[0],
222
+ temb_channels=0,
223
+ num_layers=2,
224
+ resnet_eps=norm_eps,
225
+ resnet_act_fn=act_fn,
226
+ resnet_groups=norm_num_groups,
227
+ spatial_norm_dim=in_channels,
228
+ pad_mode=pad_mode,
229
+ )
230
+
231
+ # up blocks
232
+ self.up_blocks = nn.ModuleList([])
233
+
234
+ output_channel = reversed_block_out_channels[0]
235
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
236
+
237
+ for i, up_block_type in enumerate(up_block_types):
238
+ prev_output_channel = output_channel
239
+ output_channel = reversed_block_out_channels[i]
240
+ is_final_block = i == len(block_out_channels) - 1
241
+ compress_time = i < temporal_compress_level
242
+
243
+ if up_block_type == "CogVideoXUpBlock3D":
244
+ up_block = CogVideoXUpBlock3D(
245
+ in_channels=prev_output_channel,
246
+ out_channels=output_channel,
247
+ temb_channels=0,
248
+ dropout=dropout,
249
+ num_layers=layers_per_block + 1,
250
+ resnet_eps=norm_eps,
251
+ resnet_act_fn=act_fn,
252
+ resnet_groups=norm_num_groups,
253
+ spatial_norm_dim=in_channels,
254
+ add_upsample=not is_final_block,
255
+ compress_time=compress_time,
256
+ pad_mode=pad_mode,
257
+ )
258
+ prev_output_channel = output_channel
259
+ else:
260
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
261
+
262
+ self.up_blocks.append(up_block)
263
+
264
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
265
+ self.conv_act = nn.SiLU()
266
+ self.conv_out = CogVideoXCausalConv3d(
267
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
268
+ )
269
+
270
+ self.gradient_checkpointing = False
271
+
272
+ def _clear_fake_context_parallel_cache(self):
273
+ for name, module in self.named_modules():
274
+ if isinstance(module, CogVideoXCausalConv3d):
275
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
276
+ module._clear_fake_context_parallel_cache()
277
+
278
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
279
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
280
+ hidden_states = self.conv_in(sample)
281
+
282
+ if self.training and self.gradient_checkpointing:
283
+
284
+ def create_custom_forward(module):
285
+ def custom_forward(*inputs):
286
+ return module(*inputs)
287
+
288
+ return custom_forward
289
+
290
+ # 1. Mid
291
+ hidden_states = torch.utils.checkpoint.checkpoint(
292
+ create_custom_forward(self.mid_block), hidden_states, temb, sample
293
+ )
294
+
295
+ # 2. Up
296
+ for up_block in self.up_blocks:
297
+ hidden_states = torch.utils.checkpoint.checkpoint(
298
+ create_custom_forward(up_block), hidden_states, temb, sample
299
+ )
300
+ else:
301
+ # 1. Mid
302
+ hidden_states = self.mid_block(hidden_states, temb, sample)
303
+
304
+ # 2. Up
305
+ for up_block in self.up_blocks:
306
+ hidden_states = up_block(hidden_states, temb, sample)
307
+
308
+ # 3. Post-process
309
+ hidden_states = self.norm_out(hidden_states, sample)
310
+ hidden_states = self.conv_act(hidden_states)
311
+ hidden_states = self.conv_out(hidden_states)
312
+ return hidden_states
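One detail worth spelling out: only the first log2(temporal_compression_ratio) down blocks compress the temporal axis; the remaining blocks are spatial-only. A tiny worked check of that bookkeeping:

import numpy as np

temporal_compression_ratio = 4
block_out_channels = (128, 256, 256, 512)

temporal_compress_level = int(np.log2(temporal_compression_ratio))   # -> 2
for i in range(len(block_out_channels)):
    is_final_block = i == len(block_out_channels) - 1
    compress_time = i < temporal_compress_level
    print(f"block {i}: downsample={not is_final_block}, compress_time={compress_time}")
# Blocks 0 and 1 halve the temporal axis, which yields the overall 4x temporal compression.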
easyanimate/vae/ldm/models/{enc_dec_pytorch.py → enc_dec.py} RENAMED
File without changes
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py CHANGED
@@ -1,4 +1,3 @@
1
- import itertools
2
  from dataclasses import dataclass
3
  from typing import Optional
4
 
@@ -112,10 +111,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
112
  monitor=None,
113
  ckpt_path=None,
114
  lossconfig=None,
 
115
  slice_compression_vae=False,
 
 
 
116
  mini_batch_encoder=9,
117
  mini_batch_decoder=3,
118
  train_decoder_only=False,
 
119
  ):
120
  super().__init__()
121
  self.image_key = image_key
@@ -137,7 +141,10 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
137
  act_fn=act_fn,
138
  num_attention_heads=num_attention_heads,
139
  double_z=True,
 
140
  slice_compression_vae=slice_compression_vae,
 
 
141
  mini_batch_encoder=mini_batch_encoder,
142
  )
143
 
@@ -156,7 +163,11 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
156
  norm_num_groups=norm_num_groups,
157
  act_fn=act_fn,
158
  num_attention_heads=num_attention_heads,
 
159
  slice_compression_vae=slice_compression_vae,
 
 
 
160
  mini_batch_decoder=mini_batch_decoder,
161
  )
162
 
@@ -166,9 +177,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
166
  self.mini_batch_encoder = mini_batch_encoder
167
  self.mini_batch_decoder = mini_batch_decoder
168
  self.train_decoder_only = train_decoder_only
 
169
  if train_decoder_only:
170
  self.encoder.requires_grad_(False)
171
- self.quant_conv.requires_grad_(False)
172
  if monitor is not None:
173
  self.monitor = monitor
174
  if ckpt_path is not None:
@@ -190,28 +207,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
190
  if k.startswith(ik):
191
  print("Deleting key {} from state_dict.".format(k))
192
  del sd[k]
193
- self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
194
  print(f"Restored from {path}")
 
195
 
196
  def encode(self, x: torch.Tensor) -> EncoderOutput:
197
  h = self.encoder(x)
198
 
199
- moments: torch.Tensor = self.quant_conv(h)
200
  mean, logvar = moments.chunk(2, dim=1)
201
  posterior = DiagonalGaussianDistribution(mean, logvar)
202
 
203
- # return EncoderOutput(latent_dist=posterior)
204
  return posterior
205
 
206
  def decode(self, z: torch.Tensor) -> DecoderOutput:
207
- z = self.post_quant_conv(z)
208
-
209
  decoded = self.decoder(z)
210
-
211
- # return DecoderOutput(sample=decoded)
212
  return decoded
213
 
214
-
215
  def forward(self, input, sample_posterior=True):
216
  if input.ndim==4:
217
  input = input.unsqueeze(2)
@@ -235,30 +252,22 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
235
  return x
236
 
237
  def training_step(self, batch, batch_idx, optimizer_idx):
238
- # tic = time.time()
239
  inputs = self.get_input(batch, self.image_key)
240
- # print(f"get_input time {time.time() - tic}")
241
- # tic = time.time()
242
  reconstructions, posterior = self(inputs)
243
- # print(f"model forward time {time.time() - tic}")
244
 
245
  if optimizer_idx == 0:
246
- # train encoder+decoder+logvar
247
  aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
248
  last_layer=self.get_last_layer(), split="train")
249
  self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
250
  self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
251
- # print(f"cal loss time {time.time() - tic}")
252
  return aeloss
253
 
254
  if optimizer_idx == 1:
255
- # train the discriminator
256
  discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
257
  last_layer=self.get_last_layer(), split="train")
258
 
259
  self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
260
  self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
261
- # print(f"cal loss time {time.time() - tic}")
262
  return discloss
263
 
264
  def validation_step(self, batch, batch_idx):
@@ -279,17 +288,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
279
  def configure_optimizers(self):
280
  lr = self.learning_rate
281
  if self.train_decoder_only:
282
- opt_ae = torch.optim.Adam(list(self.decoder.parameters())+
283
- list(self.post_quant_conv.parameters()),
284
- lr=lr, betas=(0.5, 0.9))
 
 
 
 
 
 
 
 
285
  else:
286
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
287
- list(self.decoder.parameters())+
288
- list(self.quant_conv.parameters())+
289
- list(self.post_quant_conv.parameters()),
290
- lr=lr, betas=(0.5, 0.9))
291
- opt_disc = torch.optim.Adam(list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
292
- lr=lr, betas=(0.5, 0.9))
 
 
 
293
  return [opt_ae, opt_disc], []
294
 
295
  def get_last_layer(self):
 
 
1
  from dataclasses import dataclass
2
  from typing import Optional
3
 
 
111
  monitor=None,
112
  ckpt_path=None,
113
  lossconfig=None,
114
+ slice_mag_vae=False,
115
  slice_compression_vae=False,
116
+ cache_compression_vae=False,
117
+ cache_mag_vae=False,
118
+ spatial_group_norm=False,
119
  mini_batch_encoder=9,
120
  mini_batch_decoder=3,
121
  train_decoder_only=False,
122
+ train_encoder_only=False,
123
  ):
124
  super().__init__()
125
  self.image_key = image_key
 
141
  act_fn=act_fn,
142
  num_attention_heads=num_attention_heads,
143
  double_z=True,
144
+ slice_mag_vae=slice_mag_vae,
145
  slice_compression_vae=slice_compression_vae,
146
+ cache_compression_vae=cache_compression_vae,
147
+ spatial_group_norm=spatial_group_norm,
148
             mini_batch_encoder=mini_batch_encoder,
         )
...
             norm_num_groups=norm_num_groups,
             act_fn=act_fn,
             num_attention_heads=num_attention_heads,
+            slice_mag_vae=slice_mag_vae,
             slice_compression_vae=slice_compression_vae,
+            cache_compression_vae=cache_compression_vae,
+            cache_mag_vae=cache_mag_vae,
+            spatial_group_norm=spatial_group_norm,
             mini_batch_decoder=mini_batch_decoder,
         )
...
         self.mini_batch_encoder = mini_batch_encoder
         self.mini_batch_decoder = mini_batch_decoder
         self.train_decoder_only = train_decoder_only
+        self.train_encoder_only = train_encoder_only
         if train_decoder_only:
             self.encoder.requires_grad_(False)
+            if self.quant_conv is not None:
+                self.quant_conv.requires_grad_(False)
+        if train_encoder_only:
+            self.decoder.requires_grad_(False)
+            if self.post_quant_conv is not None:
+                self.post_quant_conv.requires_grad_(False)
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
...
             if k.startswith(ik):
                 print("Deleting key {} from state_dict.".format(k))
                 del sd[k]
+        m, u = self.load_state_dict(sd, strict=False)  # loss.item can be ignored successfully
         print(f"Restored from {path}")
+        print(f"missing keys: {str(m)}, unexpected keys: {str(u)}")

     def encode(self, x: torch.Tensor) -> EncoderOutput:
         h = self.encoder(x)

+        if self.quant_conv is not None:
+            moments: torch.Tensor = self.quant_conv(h)
+        else:
+            moments: torch.Tensor = h
         mean, logvar = moments.chunk(2, dim=1)
         posterior = DiagonalGaussianDistribution(mean, logvar)

         return posterior

     def decode(self, z: torch.Tensor) -> DecoderOutput:
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
         decoded = self.decoder(z)

         return decoded

     def forward(self, input, sample_posterior=True):
         if input.ndim==4:
             input = input.unsqueeze(2)
...
         return x

     def training_step(self, batch, batch_idx, optimizer_idx):
         inputs = self.get_input(batch, self.image_key)
         reconstructions, posterior = self(inputs)

         if optimizer_idx == 0:
             aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                             last_layer=self.get_last_layer(), split="train")
             self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
             self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
             return aeloss

         if optimizer_idx == 1:
             discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                 last_layer=self.get_last_layer(), split="train")
             self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
             self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
             return discloss

     def validation_step(self, batch, batch_idx):
...
     def configure_optimizers(self):
         lr = self.learning_rate
         if self.train_decoder_only:
+            if self.post_quant_conv is not None:
+                training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
+            else:
+                training_list = list(self.decoder.parameters())
+            opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
+        elif self.train_encoder_only:
+            if self.quant_conv is not None:
+                training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
+            else:
+                training_list = list(self.encoder.parameters())
+            opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
         else:
+            training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
+            if self.quant_conv is not None:
+                training_list = training_list + list(self.quant_conv.parameters())
+            if self.post_quant_conv is not None:
+                training_list = training_list + list(self.post_quant_conv.parameters())
+            opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(
+            list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
+            lr=lr, betas=(0.5, 0.9)
+        )
         return [opt_ae, opt_disc], []

     def get_last_layer(self):
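For readers skimming the hunk above: the optional `quant_conv`/`post_quant_conv` handling and the `DiagonalGaussianDistribution` built from `moments.chunk(2, dim=1)` follow the usual VAE encode/sample/decode pattern. Below is a minimal, self-contained sketch of that round trip (a hypothetical `TinyVideoVAE` using plain `torch`, not the repository's class), with the quant convs made optional in the same way:

```python
# Minimal sketch only: illustrates chunking moments into mean/logvar,
# reparameterized sampling, and optional (possibly None) quant convs.
import torch
import torch.nn as nn

class TinyVideoVAE(nn.Module):
    def __init__(self, latent_channels=4, use_quant_conv=True):
        super().__init__()
        self.encoder = nn.Conv3d(3, 2 * latent_channels, kernel_size=3, padding=1)
        self.decoder = nn.Conv3d(latent_channels, 3, kernel_size=3, padding=1)
        # Optional 1x1x1 convs, kept as None when disabled (mirroring the diff above).
        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_quant_conv else None

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h) if self.quant_conv is not None else h
        mean, logvar = moments.chunk(2, dim=1)
        return mean, logvar

    def decode(self, z):
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, x, sample_posterior=True):
        mean, logvar = self.encode(x)
        # Reparameterized sample from the diagonal Gaussian posterior.
        z = (mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)) if sample_posterior else mean
        return self.decode(z)

video = torch.randn(1, 3, 5, 32, 32)   # (B, C, T, H, W)
print(TinyVideoVAE()(video).shape)      # torch.Size([1, 3, 5, 32, 32])
```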
easyanimate/vae/ldm/models/omnigen_enc_dec.py CHANGED
@@ -1,6 +1,10 @@
+from typing import Any, Dict
+
 import torch
 import torch.nn as nn
-import numpy as np
+from diffusers.utils import is_torch_version
+from einops import rearrange
+
 from ..modules.vaemodules.activations import get_activation
 from ..modules.vaemodules.common import CausalConv3d
 from ..modules.vaemodules.down_blocks import get_down_block
@@ -8,6 +12,16 @@ from ..modules.vaemodules.mid_blocks import get_mid_block
 from ..modules.vaemodules.up_blocks import get_up_block
 
 
+def create_custom_forward(module, return_dict=None):
+    def custom_forward(*inputs):
+        if return_dict is not None:
+            return module(*inputs, return_dict=return_dict)
+        else:
+            return module(*inputs)
+
+    return custom_forward
+
+
 class Encoder(nn.Module):
     r"""
     The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -54,7 +68,11 @@ class Encoder(nn.Module):
         act_fn: str = "silu",
         num_attention_heads: int = 1,
         double_z: bool = True,
+        slice_mag_vae: bool = False,
         slice_compression_vae: bool = False,
+        cache_compression_vae: bool = False,
+        cache_mag_vae: bool = False,
+        spatial_group_norm: bool = False,
         mini_batch_encoder: int = 9,
         verbose = False,
     ):
@@ -118,9 +136,12 @@ class Encoder(nn.Module):
         conv_out_channels = 2 * out_channels if double_z else out_channels
         self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
 
+        self.slice_mag_vae = slice_mag_vae
         self.slice_compression_vae = slice_compression_vae
+        self.cache_compression_vae = cache_compression_vae
+        self.cache_mag_vae = cache_mag_vae
         self.mini_batch_encoder = mini_batch_encoder
-        self.features_share = False
+        self.spatial_group_norm = spatial_group_norm
         self.verbose = verbose
 
     def set_padding_one_frame(self):
@@ -145,36 +166,142 @@ class Encoder(nn.Module):
         for name, module in self.named_children():
             _set_padding_more_frame(name, module)
 
+    def set_magvit_padding_one_frame(self):
+        def _set_magvit_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 3
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_one_frame(name, module)
+
+    def set_magvit_padding_more_frame(self):
+        def _set_magvit_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 4
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_more_frame(name, module)
+
+    def set_cache_slice_vae_padding_one_frame(self):
+        def _set_cache_slice_vae_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 5
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_one_frame(name, module)
+
+    def set_cache_slice_vae_padding_more_frame(self):
+        def _set_cache_slice_vae_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 6
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_more_frame(name, module)
+
+    def set_3dgroupnorm_for_submodule(self):
+        def _set_3dgroupnorm_for_submodule(name, module):
+            if hasattr(module, 'set_3dgroupnorm'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.set_3dgroupnorm = True
+            for sub_name, sub_mod in module.named_children():
+                _set_3dgroupnorm_for_submodule(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_3dgroupnorm_for_submodule(name, module)
+
     def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
-        if self.features_share and previous_features is not None and after_features is None:
+        if self.training:
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+        if previous_features is not None and after_features is None:
             x = torch.concat([previous_features, x], 2)
-        elif self.features_share and previous_features is None and after_features is not None:
+        elif previous_features is None and after_features is not None:
             x = torch.concat([x, after_features], 2)
-        elif self.features_share and previous_features is not None and after_features is not None:
+        elif previous_features is not None and after_features is not None:
             x = torch.concat([previous_features, x, after_features], 2)
 
-        x = self.conv_in(x)
-
+        if self.training:
+            x = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.conv_in),
+                x,
+                **ckpt_kwargs,
+            )
+        else:
+            x = self.conv_in(x)
         for down_block in self.down_blocks:
-            x = down_block(x)
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    x,
+                    **ckpt_kwargs,
+                )
+            else:
+                x = down_block(x)
 
         x = self.mid_block(x)
 
-        x = self.conv_norm_out(x)
+        if self.spatial_group_norm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.conv_norm_out(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.conv_norm_out(x)
         x = self.conv_act(x)
         x = self.conv_out(x)
 
-        if self.features_share and previous_features is not None and after_features is None:
+        if previous_features is not None and after_features is None:
             x = x[:, :, 1:]
-        elif self.features_share and previous_features is None and after_features is not None:
+        elif previous_features is None and after_features is not None:
             x = x[:, :, :2]
-        elif self.features_share and previous_features is not None and after_features is not None:
+        elif previous_features is not None and after_features is not None:
             x = x[:, :, 1:3]
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.slice_compression_vae:
+        if self.spatial_group_norm:
+            self.set_3dgroupnorm_for_submodule()
+
+        if self.cache_mag_vae:
+            self.set_magvit_padding_one_frame()
+            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+            self.set_magvit_padding_more_frame()
+            new_pixel_values = [first_frames]
+            for i in range(1, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.cache_compression_vae:
+            _, _, f, _, _ = x.size()
+            if f % 2 != 0:
+                self.set_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+                self.set_padding_more_frame()
+
+                new_pixel_values = [first_frames]
+                start_index = 1
+            else:
+                self.set_padding_more_frame()
+                new_pixel_values = []
+                start_index = 0
+
+            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_compression_vae:
             _, _, f, _, _ = x.size()
             if f % 2 != 0:
                 self.set_padding_one_frame()
@@ -188,11 +315,15 @@ class Encoder(nn.Module):
             new_pixel_values = []
             start_index = 0
 
-            previous_features = None
             for i in range(start_index, x.shape[2], self.mini_batch_encoder):
-                after_features = x[:, :, i + self.mini_batch_encoder: i + self.mini_batch_encoder + 4, :, :] if i + self.mini_batch_encoder < x.shape[2] else None
-                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], previous_features, after_features)
-                previous_features = x[:, :, i + self.mini_batch_encoder - 4: i + self.mini_batch_encoder, :, :]
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_mag_vae:
+            _, _, f, _, _ = x.size()
+            new_pixel_values = []
+            for i in range(0, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                 new_pixel_values.append(next_frames)
             new_pixel_values = torch.cat(new_pixel_values, dim=2)
         else:
@@ -242,7 +373,11 @@ class Decoder(nn.Module):
         norm_num_groups: int = 32,
         act_fn: str = "silu",
         num_attention_heads: int = 1,
+        slice_mag_vae: bool = False,
         slice_compression_vae: bool = False,
+        cache_compression_vae: bool = False,
+        cache_mag_vae: bool = False,
+        spatial_group_norm: bool = False,
         mini_batch_decoder: int = 3,
         verbose = False,
     ):
@@ -309,9 +444,12 @@ class Decoder(nn.Module):
 
         self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
 
+        self.slice_mag_vae = slice_mag_vae
         self.slice_compression_vae = slice_compression_vae
+        self.cache_compression_vae = cache_compression_vae
+        self.cache_mag_vae = cache_mag_vae
         self.mini_batch_decoder = mini_batch_decoder
-        self.features_share = True
+        self.spatial_group_norm = spatial_group_norm
         self.verbose = verbose
 
     def set_padding_one_frame(self):
@@ -335,22 +473,90 @@
                 _set_padding_more_frame(sub_name, sub_mod)
         for name, module in self.named_children():
             _set_padding_more_frame(name, module)
+
+    def set_magvit_padding_one_frame(self):
+        def _set_magvit_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 3
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_one_frame(name, module)
+
+    def set_magvit_padding_more_frame(self):
+        def _set_magvit_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 4
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_more_frame(name, module)
+
+    def set_cache_slice_vae_padding_one_frame(self):
+        def _set_cache_slice_vae_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 5
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_one_frame(name, module)
+
+    def set_cache_slice_vae_padding_more_frame(self):
+        def _set_cache_slice_vae_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 6
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_more_frame(name, module)
+
+    def set_3dgroupnorm_for_submodule(self):
+        def _set_3dgroupnorm_for_submodule(name, module):
+            if hasattr(module, 'set_3dgroupnorm'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.set_3dgroupnorm = True
+            for sub_name, sub_mod in module.named_children():
+                _set_3dgroupnorm_for_submodule(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_3dgroupnorm_for_submodule(name, module)
+
+    def clear_cache(self):
+        def _clear_cache(name, module):
+            if hasattr(module, 'prev_features'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.prev_features = None
+            for sub_name, sub_mod in module.named_children():
+                _clear_cache(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _clear_cache(name, module)
 
     def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
-        if self.features_share and previous_features is not None and after_features is None:
+        if self.training:
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+        if previous_features is not None and after_features is None:
            b, c, t, h, w = x.size()
            x = torch.concat([previous_features, x], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, -t:]
-        elif self.features_share and previous_features is None and after_features is not None:
+        elif previous_features is None and after_features is not None:
            b, c, t, h, w = x.size()
            x = torch.concat([x, after_features], 2)
            x = self.conv_in(x)
            x = self.mid_block(x)
            x = x[:, :, :t]
-        elif self.features_share and previous_features is not None and after_features is not None:
+        elif previous_features is not None and after_features is not None:
            _, _, t_1, _, _ = previous_features.size()
            _, _, t_2, _, _ = x.size()
            x = torch.concat([previous_features, x, after_features], 2)
@@ -358,20 +564,76 @@
            x = self.mid_block(x)
            x = x[:, :, t_1:(t_1 + t_2)]
         else:
-            x = self.conv_in(x)
-            x = self.mid_block(x)
-
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.conv_in),
+                    x,
+                    **ckpt_kwargs,
+                )
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block),
+                    x,
+                    **ckpt_kwargs,
+                )
+            else:
+                x = self.conv_in(x)
+                x = self.mid_block(x)
+
         for up_block in self.up_blocks:
-            x = up_block(x)
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    x,
+                    **ckpt_kwargs,
+                )
+            else:
+                x = up_block(x)
 
-        x = self.conv_norm_out(x)
+        if self.spatial_group_norm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.conv_norm_out(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.conv_norm_out(x)
         x = self.conv_act(x)
         x = self.conv_out(x)
 
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.slice_compression_vae:
+        if self.spatial_group_norm:
+            self.set_3dgroupnorm_for_submodule()
+
+        if self.cache_mag_vae:
+            self.set_magvit_padding_one_frame()
+            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+            self.set_magvit_padding_more_frame()
+            new_pixel_values = [first_frames]
+            for i in range(1, x.shape[2], self.mini_batch_decoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.cache_compression_vae:
+            _, _, f, _, _ = x.size()
+            if f == 1:
+                self.set_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, :1, :, :], None, None)
+                new_pixel_values = [first_frames]
+                start_index = 1
+            else:
+                self.set_cache_slice_vae_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, :self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values = [first_frames]
+                start_index = self.mini_batch_decoder
+
+            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
+                self.set_cache_slice_vae_padding_more_frame()
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_compression_vae:
             _, _, f, _, _ = x.size()
             if f % 2 != 0:
                 self.set_padding_one_frame()
@@ -391,6 +653,13 @@
                 previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
                 new_pixel_values.append(next_frames)
             new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_mag_vae:
+            _, _, f, _, _ = x.size()
+            new_pixel_values = []
+            for i in range(0, x.shape[2], self.mini_batch_decoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
         else:
             new_pixel_values = self.single_forward(x, None, None)
         return new_pixel_values
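All of the new `forward` branches above (`cache_mag_vae`, `cache_compression_vae`, `slice_compression_vae`, `slice_mag_vae`) share one idea: run `single_forward` on `mini_batch_encoder` / `mini_batch_decoder` frames at a time along the temporal axis and concatenate the results on dim 2, optionally handling the first frame separately. A minimal sketch of that slicing pattern, with a hypothetical `single_forward` stand-in and illustrative shapes:

```python
# Minimal sketch of the temporal mini-batch slicing used by the new forward()
# branches above (hypothetical helper, not the repository's code): frames are
# processed `mini_batch` at a time along dim 2 and re-concatenated.
import torch

def sliced_temporal_forward(x: torch.Tensor, single_forward, mini_batch: int = 9,
                            separate_first_frame: bool = True) -> torch.Tensor:
    # x: (B, C, T, H, W)
    chunks = []
    start = 0
    if separate_first_frame:
        # e.g. the cache_mag_vae branch processes frame 0 on its own first.
        chunks.append(single_forward(x[:, :, 0:1]))
        start = 1
    for i in range(start, x.shape[2], mini_batch):
        chunks.append(single_forward(x[:, :, i:i + mini_batch]))
    return torch.cat(chunks, dim=2)

identity = lambda clip: clip  # stand-in for Encoder/Decoder.single_forward
video = torch.randn(1, 4, 17, 8, 8)
assert sliced_temporal_forward(video, identity).shape == video.shape
```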
easyanimate/vae/ldm/modules/ema.py CHANGED
@@ -1,7 +1,8 @@
 #-*- encoding:utf-8 -*-
 import torch
-from torch import nn
 from pytorch_lightning.callbacks import Callback
+from torch import nn
+
 
 class LitEma(nn.Module):
     def __init__(self, model, decay=0.9999, use_num_upates=True):
easyanimate/vae/ldm/modules/losses/contperceptual.py CHANGED
@@ -2,8 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?
+
 from ..vaemodules.discriminator import Discriminator3D
 
+
 class LPIPSWithDiscriminator(nn.Module):
     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
@@ -62,15 +64,6 @@ class LPIPSWithDiscriminator(nn.Module):
 
         # get new loss_weight
         loss_weights = 1
-        # b, _ ,f, _, _ = reconstructions.size()
-        # loss_weights = torch.ones([b, f]).view(b, 1, f, 1, 1)
-        # loss_weights[:, :, 0] = 3
-        # for i in range(1, f, 8):
-        #     loss_weights[:, :, i - 1] = 3
-        #     loss_weights[:, :, i] = 3
-        # loss_weights[:, :, -1] = 3
-        # loss_weights = loss_weights.permute(0, 2, 1, 3, 4).flatten(0, 1).to(reconstructions.device)
-
         inputs = inputs.permute(0, 2, 1, 3, 4).flatten(0, 1)
         reconstructions = reconstructions.permute(0, 2, 1, 3, 4).flatten(0, 1)
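The hunk above keeps the per-frame loss path: both `inputs` and `reconstructions` are permuted and flattened from `(B, C, T, H, W)` to `(B*T, C, H, W)` so the 2D LPIPS and discriminator terms score each frame independently. A minimal sketch of that reshaping (shapes are illustrative assumptions):

```python
# Minimal sketch of the frame flattening used just above: a 5D video batch is
# folded into a 4D batch so 2D per-frame losses can be applied directly.
import torch

def flatten_frames(video: torch.Tensor) -> torch.Tensor:
    # (B, C, T, H, W) -> (B*T, C, H, W)
    return video.permute(0, 2, 1, 3, 4).flatten(0, 1)

x = torch.randn(2, 3, 8, 16, 16)
print(flatten_frames(x).shape)  # torch.Size([16, 3, 16, 16])
```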
 
easyanimate/vae/ldm/modules/vaemodules/common.py CHANGED
@@ -38,7 +38,7 @@ class CausalConv3d(nn.Conv3d):
         assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
 
         t_ks, h_ks, w_ks = kernel_size
-        _, h_stride, w_stride = stride
+        self.t_stride, h_stride, w_stride = stride
         t_dilation, h_dilation, w_dilation = dilation
 
         t_pad = (t_ks - 1) * t_dilation
@@ -54,6 +54,7 @@
         self.temporal_padding = t_pad
         self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
         self.padding_flag = 0
+        self.prev_features = None
 
         super().__init__(
             in_channels=in_channels,
@@ -67,38 +68,81 @@
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
+        dtype = x.dtype
+        x = x.float()
         if self.padding_flag == 0:
             x = F.pad(
                 x,
                 pad=(0, 0, 0, 0, self.temporal_padding, 0),
                 mode="replicate",  # TODO: check if this is necessary
             )
+            x = x.to(dtype=dtype)
+            return super().forward(x)
+        elif self.padding_flag == 3:
+            x = F.pad(
+                x,
+                pad=(0, 0, 0, 0, self.temporal_padding, 0),
+                mode="replicate",  # TODO: check if this is necessary
+            )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+
+            b, c, f, h, w = x.size()
+            outputs = []
+            i = 0
+            while i + self.temporal_padding + 1 <= f:
+                out = super().forward(x[:, :, i:i + self.temporal_padding + 1])
+                i += self.t_stride
+                outputs.append(out)
+            return torch.concat(outputs, 2)
+        elif self.padding_flag == 4:
+            if self.t_stride == 2:
+                x = torch.concat(
+                    [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2
+                )
+            else:
+                x = torch.concat(
+                    [self.prev_features, x], dim = 2
+                )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+
+            b, c, f, h, w = x.size()
+            outputs = []
+            i = 0
+            while i + self.temporal_padding + 1 <= f:
+                out = super().forward(x[:, :, i:i + self.temporal_padding + 1])
+                i += self.t_stride
+                outputs.append(out)
+            return torch.concat(outputs, 2)
+        elif self.padding_flag == 5:
+            x = F.pad(
+                x,
+                pad=(0, 0, 0, 0, self.temporal_padding, 0),
+                mode="replicate",  # TODO: check if this is necessary
+            )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+            return super().forward(x)
+        elif self.padding_flag == 6:
+            if self.t_stride == 2:
+                x = torch.concat(
+                    [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2
+                )
+            else:
+                x = torch.concat(
+                    [self.prev_features, x], dim = 2
+                )
+            self.prev_features = x[:, :, -self.temporal_padding:]
+            x = x.to(dtype=dtype)
+            return super().forward(x)
         else:
             x = F.pad(
                 x,
                 pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
             )
-        return super().forward(x)
-
-    def set_padding_one_frame(self):
-        def _set_padding_one_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 1
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_one_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_one_frame(name, module)
-
-    def set_padding_more_frame(self):
-        def _set_padding_more_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 2
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_more_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_more_frame(name, module)
+            x = x.to(dtype=dtype)
+            return super().forward(x)
 
 class ResidualBlock2D(nn.Module):
     def __init__(
@@ -142,15 +186,29 @@ class ResidualBlock2D(nn.Module):
         else:
             self.shortcut = nn.Identity()
 
+        self.set_3dgroupnorm = False
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = self.shortcut(x)
 
-        x = self.norm1(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm1(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm1(x)
         x = self.nonlinearity(x)
 
         x = self.conv1(x)
 
-        x = self.norm2(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm2(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm2(x)
         x = self.nonlinearity(x)
 
         x = self.dropout(x)
@@ -201,15 +259,29 @@ class ResidualBlock3D(nn.Module):
         else:
             self.shortcut = nn.Identity()
 
+        self.set_3dgroupnorm = False
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = self.shortcut(x)
 
-        x = self.norm1(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm1(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm1(x)
        x = self.nonlinearity(x)
 
         x = self.conv1(x)
 
-        x = self.norm2(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm2(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm2(x)
         x = self.nonlinearity(x)
 
         x = self.dropout(x)
@@ -238,11 +310,18 @@ class SpatialNorm2D(nn.Module):
         self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
         self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
         self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.set_3dgroupnorm = False
 
     def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
         f_size = f.shape[-2:]
         zq = F.interpolate(zq, size=f_size, mode="nearest")
-        norm_f = self.norm(f)
+        if self.set_3dgroupnorm:
+            batch_size = f.shape[0]
+            f = rearrange(f, "b c t h w -> (b t) c h w")
+            norm_f = self.norm(f)
+            norm_f = rearrange(norm_f, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            norm_f = self.norm(f)
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
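The new `set_3dgroupnorm` branches above all use the same trick: fold the time axis into the batch with `rearrange`, apply the existing 2D `GroupNorm` so statistics are computed per frame, then unfold. A minimal sketch of that pattern (group count and shapes are illustrative assumptions):

```python
# Minimal sketch of the per-frame ("spatial") GroupNorm pattern used above.
# Note: nn.GroupNorm would also accept the 5D tensor directly, but then the
# statistics would be computed across all frames; folding T into the batch
# normalizes each frame independently, which is the point of the flag.
import torch
import torch.nn as nn
from einops import rearrange

norm = nn.GroupNorm(num_groups=8, num_channels=32)
x = torch.randn(2, 32, 5, 16, 16)             # (B, C, T, H, W)

b = x.shape[0]
y = rearrange(x, "b c t h w -> (b t) c h w")  # fold T into the batch
y = norm(y)
y = rearrange(y, "(b t) c h w -> b c t h w", b=b)
print(y.shape)  # torch.Size([2, 32, 5, 16, 16])
```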
easyanimate/vae/ldm/modules/vaemodules/upsamplers.py CHANGED
@@ -137,6 +137,7 @@ class SpatialTemporalUpsampler3D(Upsampler):
         )
 
         self.padding_flag = 0
+        self.set_3dgroupnorm = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
@@ -145,32 +146,12 @@ class SpatialTemporalUpsampler3D(Upsampler):
         if self.padding_flag == 0:
             if x.shape[2] > 1:
                 first_frame, x = x[:, :, :1], x[:, :, 1:]
-                x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+                x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
                 x = torch.cat([first_frame, x], dim=2)
-        elif self.padding_flag == 2:
-            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+        elif self.padding_flag == 2 or self.padding_flag == 4 or self.padding_flag == 5 or self.padding_flag == 6:
+            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
         return x
 
-    def set_padding_one_frame(self):
-        def _set_padding_one_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 1
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_one_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_one_frame(name, module)
-
-    def set_padding_more_frame(self):
-        def _set_padding_more_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 2
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_more_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_more_frame(name, module)
-
 class SpatialTemporalUpsamplerD2S3D(Upsampler):
     def __init__(self, in_channels: int, out_channels: int):
         super().__init__(
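The `padding_flag == 0` path of `SpatialTemporalUpsampler3D.forward` above first doubles H and W for every frame, then doubles T for all frames except the first, so the causal first frame is kept as-is. A minimal sketch of those two interpolation steps as a free function (the real module is an `Upsampler` subclass that also applies its convolution, which this sketch omits):

```python
# Minimal sketch of the two-step spatial-then-temporal upsampling shown above.
import torch
import torch.nn.functional as F

def spatial_temporal_upsample(x: torch.Tensor) -> torch.Tensor:
    # x: (B, C, T, H, W); double H and W for every frame.
    x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
    if x.shape[2] > 1:
        # Keep the first frame as-is in time, double the rest temporally.
        first_frame, rest = x[:, :, :1], x[:, :, 1:]
        rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode="trilinear")
        x = torch.cat([first_frame, rest], dim=2)
    return x

x = torch.randn(1, 4, 9, 8, 8)
print(spatial_temporal_upsample(x).shape)  # torch.Size([1, 4, 17, 16, 16])
```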
easyanimate/video_caption/README.md DELETED
@@ -1,90 +0,0 @@
-# Video Caption
-EasyAnimate first uses multi-modal LLMs to generate captions for frames extracted from the video, and then employs LLMs to summarize and refine the generated frame captions into the final video caption. By leveraging [sglang](https://github.com/sgl-project/sglang)/[vLLM](https://github.com/vllm-project/vllm) and [accelerate distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference), the entire process can be very fast.
-
-English | [简体中文](./README_zh-CN.md)
-
-## Quick Start
-1. Cloud usage: AliyunDSW/Docker
-
-    Check [README.md](../../README.md#quick-start) for details.
-
-2. Local usage
-
-    ```shell
-    # Install EasyAnimate requirements first.
-    cd EasyAnimate && pip install -r requirements.txt
-
-    # Install additional requirements for video caption.
-    cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
-
-    # Use DDP instead of DP in EasyOCR detection.
-    site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
-    cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py
-
-    # We strongly recommend using Docker unless you can properly handle the dependency between vllm and torch (CUDA).
-    ```
-
-## Data preprocessing
-Data preprocessing can be divided into three parts:
-
-- Video cut.
-- Video cleaning.
-- Video caption.
-
-The input for data preprocessing can be a video folder or a metadata file (txt/csv/jsonl) containing the video path column. Please check the `get_video_path_list` function in [utils/video_utils.py](utils/video_utils.py) for details.
-
-For easier understanding, we use one sample from Panda70m as an example for data preprocessing ([Download here](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4)). Please download the video and put it in "datasets/panda_70m/before_vcut/".
-
-    ```
-    📦 datasets/
-    ├── 📂 panda_70m/
-    │   └── 📂 before_vcut/
-    │       └── 📄 --C66yU3LjM_2.mp4
-    ```
-
-1. Video cut
-
-    For long video cut, EasyAnimate utilizes PySceneDetect to identify scene changes within the video and performs scene cutting based on certain threshold values to ensure consistency in the themes of the video segments. After cutting, we only keep segments with lengths ranging from 3 to 10 seconds for model training.
-
-    We have completed the parameters for ```stage_1_video_cut.sh```, so it can be run directly with the command ```sh stage_1_video_cut.sh```. After executing ```stage_1_video_cut.sh```, we obtain short videos in ```easyanimate/video_caption/datasets/panda_70m/train```.
-
-    ```shell
-    sh stage_1_video_cut.sh
-    ```
-2. Video cleaning
-
-    Following SVD's data preparation process, EasyAnimate provides a simple yet effective data processing pipeline for high-quality data filtering and labeling. It also supports distributed processing to accelerate the speed of data preprocessing. The overall process is as follows:
-
-    - Duration filtering: Analyze the basic information of the video to filter out low-quality videos that are short in duration or low in resolution. This filtering corresponds to the video cut (3s ~ 10s videos).
-    - Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of 4 uniformly sampled frames.
-    - Text filtering: Use easyocr to calculate the text proportion of middle frames to filter out videos with a large proportion of text.
-    - Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly.
-
-    The process file of **Aesthetic filtering** is ```compute_video_frame_quality.py```. After executing ```compute_video_frame_quality.py```, we obtain the file ```datasets/panda_70m/aesthetic_score.jsonl```, where each line corresponds to the aesthetic score of one video.
-
-    The process file of **Text filtering** is ```compute_text_score.py```. After executing ```compute_text_score.py```, we obtain the file ```datasets/panda_70m/text_score.jsonl```, where each line corresponds to the text score of one video.
-
-    The process file of **Motion filtering** is ```compute_motion_score.py```. Motion filtering builds on Aesthetic filtering and Text filtering; only samples that meet certain aesthetic and text scores undergo the motion score calculation. After executing ```compute_motion_score.py```, we obtain the file ```datasets/panda_70m/motion_score.jsonl```, where each line corresponds to the motion score of one video.
-
-    Then we need to filter videos by motion scores. After executing ```filter_videos_by_motion_score.py```, we get the file ```datasets/panda_70m/train.jsonl```, which includes the videos we need to caption.
-
-    We have completed the parameters for ```stage_2_filter_data.sh```, so it can be run directly with the command ```sh stage_2_filter_data.sh```.
-
-    ```shell
-    sh stage_2_filter_data.sh
-    ```
-3. Video caption
-
-    Video captioning is carried out in two stages. The first stage involves extracting frames from a video and generating descriptions for them. Subsequently, a large language model is used to summarize these descriptions into a caption.
-
-    We have conducted a detailed, manual comparison of open-sourced multi-modal LLMs such as [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat), etc., and found that [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) is capable of generating more detailed captions with fewer hallucinations. Additionally, it is supported by serving engines like [sglang](https://github.com/sgl-project/sglang) and [lmdeploy](https://github.com/InternLM/lmdeploy), enabling faster inference.
-
-    First, we use ```caption_video_frame.py``` to generate frame captions. Then, we use ```caption_summary.py``` to generate summary captions.
-
-    We have completed the parameters for ```stage_3_video_caption.sh```, so it can be run directly with the command ```sh stage_3_video_caption.sh```. After executing ```stage_3_video_caption.sh```, we obtain the final JSON file ```train_panda_70m.json``` for EasyAnimate training.
-
-    ```shell
-    sh stage_3_video_caption.sh
-    ```
-
-If you cannot access Hugging Face, you can run `export HF_ENDPOINT=https://hf-mirror.com` before the above commands to download the summary caption model automatically.