Spaces: Running
bubbliiiing committed
Commit: f62c8b9 (parent: ab9a89a)
Update V5
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- app.py +31 -8
- config/easyanimate_image_magvit_v2.yaml +0 -8
- config/easyanimate_image_normal_v1.yaml +0 -8
- config/easyanimate_image_slicevae_v3.yaml +0 -9
- config/easyanimate_video_casual_motion_module_v1.yaml +0 -27
- config/easyanimate_video_long_sequence_v1.yaml +0 -14
- config/{easyanimate_video_motion_module_v1.yaml → easyanimate_video_v1_motion_module.yaml} +5 -7
- config/{easyanimate_video_slicevae_motion_module_v3.yaml → easyanimate_video_v2_magvit_motion_module.yaml} +11 -9
- config/{easyanimate_video_magvit_motion_module_v2.yaml → easyanimate_video_v3_slicevae_motion_module.yaml} +24 -11
- config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml +20 -0
- config/easyanimate_video_v5_magvit_multi_text_encoder.yaml +19 -0
- config/zero_stage2_config.json +16 -0
- easyanimate/api/api.py +55 -9
- easyanimate/api/post_infer.py +0 -1
- easyanimate/data/dataset_image_video.py +311 -22
- easyanimate/models/__init__.py +16 -0
- easyanimate/models/attention.py +437 -659
- easyanimate/models/autoencoder_magvit.py +520 -4
- easyanimate/models/embeddings.py +107 -0
- easyanimate/models/norm.py +55 -2
- easyanimate/models/patch.py +0 -9
- easyanimate/models/processor.py +312 -0
- easyanimate/models/resampler.py +146 -0
- easyanimate/models/transformer2d.py +23 -58
- easyanimate/models/transformer3d.py +762 -70
- easyanimate/pipeline/pipeline_easyanimate.py +29 -39
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +90 -138
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +925 -0
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py +996 -0
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +1334 -0
- easyanimate/ui/ui.py +0 -0
- easyanimate/utils/discrete_sampler.py +46 -0
- easyanimate/utils/fp8_optimization.py +28 -0
- easyanimate/utils/lora_utils.py +26 -20
- easyanimate/utils/utils.py +64 -20
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml +64 -0
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml +65 -0
- easyanimate/vae/ldm/data/dataset_callback.py +1 -0
- easyanimate/vae/ldm/data/dataset_image_video.py +7 -4
- easyanimate/vae/ldm/models/casual3dcnn.py +337 -0
- easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +326 -0
- easyanimate/vae/ldm/models/cogvideox_enc_dec.py +312 -0
- easyanimate/vae/ldm/models/{enc_dec_pytorch.py → enc_dec.py} +0 -0
- easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +48 -28
- easyanimate/vae/ldm/models/omnigen_enc_dec.py +296 -27
- easyanimate/vae/ldm/modules/ema.py +2 -1
- easyanimate/vae/ldm/modules/losses/contperceptual.py +2 -9
- easyanimate/vae/ldm/modules/vaemodules/common.py +106 -27
- easyanimate/vae/ldm/modules/vaemodules/upsamplers.py +4 -23
- easyanimate/video_caption/README.md +0 -90
app.py
CHANGED
@@ -1,27 +1,50 @@
 import time
 
-
-
+import torch
+
+from easyanimate.api.api import (infer_forward_api,
+                                 update_diffusion_transformer_api,
+                                 update_edition_api)
+from easyanimate.ui.ui import ui, ui_eas, ui_modelscope
 
 if __name__ == "__main__":
     # Choose the ui mode
     ui_mode = "eas"
+
+    # GPU memory mode, which can be choosen in ["model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"].
+    # "model_cpu_offload" means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+    #
+    # "model_cpu_offload_and_qfloat8" indicates that the entire model will be moved to the CPU after use,
+    # and the transformer model has been quantized to float8, which can save more GPU memory.
+    #
+    # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use,
+    # resulting in slower speeds but saving a large amount of GPU memory.
+    GPU_memory_mode = "model_cpu_offload_and_qfloat8"
+    # Use torch.float16 if GPU does not support torch.bfloat16
+    # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+    weight_dtype = torch.bfloat16
+
     # Server ip
     server_name = "0.0.0.0"
     server_port = 7860
 
     # Params below is used when ui_mode = "modelscope"
-    edition = "
-
-
+    edition = "v5"
+    # Config
+    config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml"
+    # Model path of the pretrained model
+    model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    # "Inpaint" or "Control"
+    model_type = "Inpaint"
+    # Save dir
     savedir_sample = "samples"
 
     if ui_mode == "modelscope":
-        demo, controller = ui_modelscope(edition, config_path, model_name, savedir_sample)
+        demo, controller = ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype)
     elif ui_mode == "eas":
         demo, controller = ui_eas(edition, config_path, model_name, savedir_sample)
     else:
-        demo, controller = ui()
+        demo, controller = ui(GPU_memory_mode, weight_dtype)
 
     # launch gradio
     app, _, _ = demo.queue(status_update_rate=1).launch(
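For reference, a minimal sketch of how the three `GPU_memory_mode` strings introduced above typically map onto diffusers-style offloading calls. This is an illustrative assumption, not the code in `easyanimate/ui/ui.py`; the `apply_gpu_memory_mode` helper and the `pipeline` argument are hypothetical names.

```python
def apply_gpu_memory_mode(pipeline, GPU_memory_mode):
    # Illustrative sketch: `pipeline` is assumed to be a diffusers-style
    # DiffusionPipeline exposing the standard offloading helpers.
    if GPU_memory_mode == "sequential_cpu_offload":
        # Each submodule is moved to GPU only while it runs: slowest, least VRAM.
        pipeline.enable_sequential_cpu_offload()
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # The float8 weight quantization itself lives in
        # easyanimate/utils/fp8_optimization.py (also touched in this commit)
        # and is not reproduced here; afterwards whole models are offloaded.
        pipeline.enable_model_cpu_offload()
    else:  # "model_cpu_offload"
        # Whole models are moved back to CPU after each use.
        pipeline.enable_model_cpu_offload()
    return pipeline
```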
config/easyanimate_image_magvit_v2.yaml
DELETED
@@ -1,8 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: true
config/easyanimate_image_normal_v1.yaml
DELETED
@@ -1,8 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/easyanimate_image_slicevae_v3.yaml
DELETED
@@ -1,9 +0,0 @@
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: true
-  slice_compression_vae: true
config/easyanimate_video_casual_motion_module_v1.yaml
DELETED
@@ -1,27 +0,0 @@
-transformer_additional_kwargs:
-  patch_3d: false
-  fake_3d: false
-  casual_3d: true
-  casual_3d_upsampler_index: [16, 20]
-  time_patch_size: 4
-  basic_block_type: "motionmodule"
-  time_position_encoding_before_transformer: false
-  motion_module_type: "VanillaGrid"
-
-  motion_module_kwargs:
-    num_attention_heads: 8
-    num_transformer_block: 1
-    attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
-    temporal_position_encoding: true
-    temporal_position_encoding_max_len: 4096
-    temporal_attention_dim_div: 1
-    block_size: 2
-
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/easyanimate_video_long_sequence_v1.yaml
DELETED
@@ -1,14 +0,0 @@
-transformer_additional_kwargs:
-  patch_3d: false
-  fake_3d: false
-  basic_block_type: "selfattentiontemporal"
-  time_position_encoding_before_transformer: true
-
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
-vae_kwargs:
-  enable_magvit: false
config/{easyanimate_video_motion_module_v1.yaml → easyanimate_video_v1_motion_module.yaml}
RENAMED
@@ -1,4 +1,5 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
   basic_block_type: "motionmodule"
@@ -14,11 +15,8 @@ transformer_additional_kwargs:
     temporal_attention_dim_div: 1
     block_size: 2
 
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
 vae_kwargs:
-
+  vae_type: "AutoencoderKL"
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/{easyanimate_video_slicevae_motion_module_v3.yaml → easyanimate_video_v2_magvit_motion_module.yaml}
RENAMED
@@ -1,4 +1,5 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
   basic_block_type: "motionmodule"
@@ -15,13 +16,14 @@ transformer_additional_kwargs:
     temporal_attention_dim_div: 1
     block_size: 1
 
-noise_scheduler_kwargs:
-  beta_start: 0.0001
-  beta_end: 0.02
-  beta_schedule: "linear"
-  steps_offset: 1
-
 vae_kwargs:
-
-
-
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 9
+  mini_batch_decoder: 3
+  slice_mag_vae: true
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/{easyanimate_video_magvit_motion_module_v2.yaml → easyanimate_video_v3_slicevae_motion_module.yaml}
RENAMED
@@ -1,26 +1,39 @@
 transformer_additional_kwargs:
+  transformer_type: "Transformer3DModel"
   patch_3d: false
   fake_3d: false
-  basic_block_type: "
+  basic_block_type: "global_motionmodule"
   time_position_encoding_before_transformer: false
   motion_module_type: "Vanilla"
   enable_uvit: true
 
-
-    num_attention_heads:
+  motion_module_kwargs_even:
+    num_attention_heads: 16
     num_transformer_block: 1
     attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
     temporal_position_encoding: true
     temporal_position_encoding_max_len: 4096
     temporal_attention_dim_div: 1
     block_size: 1
-
-
-
-
-
-
+    remove_time_embedding_in_photo: false
+  motion_module_kwargs_odd:
+    num_attention_heads: 16
+    num_transformer_block: 1
+    attention_block_types: [ "Temporal_Self", "Global_Self" ]
+    temporal_position_encoding: true
+    temporal_position_encoding_max_len: 4096
+    temporal_attention_dim_div: 1
+    block_size: 1
+    remove_time_embedding_in_photo: false
 
 vae_kwargs:
-
-  mini_batch_encoder:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 8
+  mini_batch_decoder: 2
+  slice_mag_vae: false
+  slice_compression_vae: true
+  cache_compression_vae: false
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml
ADDED
@@ -0,0 +1,20 @@
+transformer_additional_kwargs:
+  transformer_type: "HunyuanTransformer3DModel"
+  basic_block_type: "basic"
+  after_norm: false
+  time_position_encoding_type: "2d_rope"
+  time_position_encoding: true
+  resize_inpaint_mask_directly: false
+  enable_clip_in_inpaint: true
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 8
+  mini_batch_decoder: 2
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: true
+  cache_mag_vae: false
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: true
config/easyanimate_video_v5_magvit_multi_text_encoder.yaml
ADDED
@@ -0,0 +1,19 @@
+transformer_additional_kwargs:
+  transformer_type: "EasyAnimateTransformer3DModel"
+  after_norm: false
+  time_position_encoding_type: "3d_rope"
+  resize_inpaint_mask_directly: true
+  enable_text_attention_mask: false
+  enable_clip_in_inpaint: false
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 4
+  mini_batch_decoder: 1
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: true
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: true
config/zero_stage2_config.json
ADDED
@@ -0,0 +1,16 @@
+{
+  "bf16": {
+    "enabled": true
+  },
+  "train_micro_batch_size_per_gpu": 1,
+  "train_batch_size": "auto",
+  "gradient_accumulation_steps": "auto",
+  "dump_state": true,
+  "zero_optimization": {
+    "stage": 2,
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": 5e8
+  }
+}
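One common way such a DeepSpeed JSON is consumed during training is through 🤗 Accelerate's DeepSpeed plugin; the sketch below is an assumption about usage, not taken from the training scripts in this commit.

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# The JSON above enables bf16, ZeRO stage 2 and "auto" batch sizing.
ds_plugin = DeepSpeedPlugin(hf_ds_config="config/zero_stage2_config.json")
accelerator = Accelerator(deepspeed_plugin=ds_plugin)

# model, optimizer and dataloader would then be wrapped as usual:
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```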
easyanimate/api/api.py
CHANGED
@@ -1,15 +1,17 @@
-import io
-import gc
 import base64
-import
-import gradio as gr
-import tempfile
+import gc
 import hashlib
+import io
+import os
+import tempfile
+from io import BytesIO
 
+import gradio as gr
+import torch
 from fastapi import FastAPI
-from io import BytesIO
 from PIL import Image
 
+
 # Function to encode a file to Base64
 def encode_file_to_base64(file_path):
     with open(file_path, "rb") as file:
@@ -53,6 +55,34 @@ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
 
     return {"message": comment}
 
+def save_base64_video(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.mp4"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def save_base64_image(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
 def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
     @app.post("/easyanimate/infer_forward")
     def _infer_forward_api(
@@ -63,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', '
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject.')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
         resize_method = datas.get('resize_method', "Generate by")
@@ -72,17 +102,20 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         base_resolution = datas.get('base_resolution', 512)
         is_image = datas.get('is_image', False)
         generation_method = datas.get('generation_method', False)
-        length_slider = datas.get('length_slider',
+        length_slider = datas.get('length_slider', 49)
         overlap_video_length = datas.get('overlap_video_length', 4)
         partial_video_length = datas.get('partial_video_length', 72)
         cfg_scale_slider = datas.get('cfg_scale_slider', 6)
         start_image = datas.get('start_image', None)
         end_image = datas.get('end_image', None)
+        validation_video = datas.get('validation_video', None)
+        validation_video_mask = datas.get('validation_video_mask', None)
+        control_video = datas.get('control_video', None)
+        denoise_strength = datas.get('denoise_strength', 0.70)
         seed_textbox = datas.get("seed_textbox", 43)
 
         generation_method = "Image Generation" if is_image else generation_method
 
-        temp_directory = tempfile.gettempdir()
         if start_image is not None:
             start_image = base64.b64decode(start_image)
             start_image = [Image.open(BytesIO(start_image))]
@@ -91,6 +124,15 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
             end_image = base64.b64decode(end_image)
             end_image = [Image.open(BytesIO(end_image))]
 
+        if validation_video is not None:
+            validation_video = save_base64_video(validation_video)
+
+        if validation_video_mask is not None:
+            validation_video_mask = save_base64_image(validation_video_mask)
+
+        if control_video is not None:
+            control_video = save_base64_video(control_video)
+
         try:
             save_sample_path, comment = controller.generate(
                 "",
@@ -113,6 +155,10 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
                 cfg_scale_slider,
                 start_image,
                 end_image,
+                validation_video,
+                validation_video_mask,
+                control_video,
+                denoise_strength,
                 seed_textbox,
                 is_api = True,
             )
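A hedged sketch of a client request against the extended endpoint. The field names mirror the `datas.get(...)` keys read in `_infer_forward_api` above; the values, the prompt text and the JSON transport are illustrative assumptions (the reference client lives in `easyanimate/api/post_infer.py`, which is not expanded in this view).

```python
import base64

import requests


def encode_file_to_base64_str(path):
    # Mirror of the server side: files travel as base64 strings.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


datas = {
    "prompt_textbox": "A panda eating bamboo in a sunny forest.",
    "length_slider": 49,
    "denoise_strength": 0.70,
    "validation_video": encode_file_to_base64_str("input.mp4"),
    "validation_video_mask": encode_file_to_base64_str("mask.jpg"),
    "seed_textbox": 43,
}
response = requests.post(
    "http://127.0.0.1:7860/easyanimate/infer_forward",
    json=datas,
    timeout=1500,
)
print(response.json())
```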
easyanimate/api/post_infer.py
CHANGED
@@ -7,7 +7,6 @@ from io import BytesIO
 
 import cv2
 import requests
-import base64
 
 
 def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
easyanimate/data/dataset_image_video.py
CHANGED
@@ -1,24 +1,23 @@
 import csv
+import gc
 import io
 import json
 import math
 import os
 import random
+from contextlib import contextmanager
 from threading import Thread
 
 import albumentations
 import cv2
-import gc
 import numpy as np
 import torch
 import torchvision.transforms as transforms
-
-from func_timeout import func_timeout, FunctionTimedOut
 from decord import VideoReader
+from func_timeout import FunctionTimedOut, func_timeout
 from PIL import Image
 from torch.utils.data import BatchSampler, Sampler
 from torch.utils.data.dataset import Dataset
-from contextlib import contextmanager
 
 VIDEO_READER_TIMEOUT = 20
 
@@ -26,9 +25,9 @@ def get_random_mask(shape):
     f, c, h, w = shape
 
     if f != 1:
-        mask_index = np.random.
+        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
     else:
-        mask_index = np.random.
+        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
     mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
 
     if mask_index == 0:
@@ -64,6 +63,40 @@ def get_random_mask(shape):
         mask_frame_before = np.random.randint(0, f // 2)
         mask_frame_after = np.random.randint(f // 2, f)
         mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 5:
+        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
+    elif mask_index == 6:
+        num_frames_to_mask = random.randint(1, max(f // 2, 1))
+        frames_to_mask = random.sample(range(f), num_frames_to_mask)
+
+        for i in frames_to_mask:
+            block_height = random.randint(1, h // 4)
+            block_width = random.randint(1, w // 4)
+            top_left_y = random.randint(0, h - block_height)
+            top_left_x = random.randint(0, w - block_width)
+            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
+    elif mask_index == 7:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()  # semi-major axis
+        b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()  # semi-minor axis
+
+        for i in range(h):
+            for j in range(w):
+                if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
+                    mask[:, :, i, j] = 1
+    elif mask_index == 8:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
+        for i in range(h):
+            for j in range(w):
+                if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
+                    mask[:, :, i, j] = 1
+    elif mask_index == 9:
+        for idx in range(f):
+            if np.random.rand() > 0.5:
+                mask[idx, :, :, :] = 1
     else:
         raise ValueError(f"The mask_index {mask_index} is not define")
     return mask
@@ -128,19 +161,35 @@ def get_video_reader_batch(video_reader, batch_index):
     frames = video_reader.get_batch(batch_index).asnumpy()
     return frames
 
+def resize_frame(frame, target_short_side):
+    h, w, _ = frame.shape
+    if h < w:
+        if target_short_side > h:
+            return frame
+        new_h = target_short_side
+        new_w = int(target_short_side * w / h)
+    else:
+        if target_short_side > w:
+            return frame
+        new_w = target_short_side
+        new_h = int(target_short_side * h / w)
+
+    resized_frame = cv2.resize(frame, (new_w, new_h))
+    return resized_frame
+
 class ImageVideoDataset(Dataset):
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        ann_path, data_root=None,
+        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+        image_sample_size=512,
+        video_repeat=0,
+        text_drop_ratio=-1,
+        enable_bucket=False,
+        video_length_drop_start=0.1,
+        video_length_drop_end=0.9,
+        enable_inpaint=False,
+    ):
         # Loading annotations from files
         print(f"loading annotations from {ann_path} ...")
         if ann_path.endswith('.csv'):
@@ -176,11 +225,11 @@ class ImageVideoDataset(Dataset):
         # Video params
         self.video_sample_stride = video_sample_stride
         self.video_sample_n_frames = video_sample_n_frames
-        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
         self.video_transforms = transforms.Compose(
             [
-                transforms.Resize(video_sample_size
-                transforms.CenterCrop(video_sample_size),
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
@@ -193,7 +242,9 @@ class ImageVideoDataset(Dataset):
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
         ])
-
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
     def get_batch(self, idx):
         data_info = self.dataset[idx % len(self.dataset)]
 
@@ -208,7 +259,7 @@ class ImageVideoDataset(Dataset):
         with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
             min_sample_n_frames = min(
                 self.video_sample_n_frames,
-                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start))
+                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
             )
             if min_sample_n_frames == 0:
                 raise ValueError(f"No Frames in video.")
@@ -223,6 +274,12 @@ class ImageVideoDataset(Dataset):
                 pixel_values = func_timeout(
                     VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                 )
+                resized_frames = []
+                for i in range(len(pixel_values)):
+                    frame = pixel_values[i]
+                    resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                    resized_frames.append(resized_frame)
+                pixel_values = np.array(resized_frames)
             except FunctionTimedOut:
                 raise ValueError(f"Read {idx} timeout.")
             except Exception as e:
@@ -291,6 +348,238 @@ class ImageVideoDataset(Dataset):
             clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
             sample["clip_pixel_values"] = clip_pixel_values
 
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
+        return sample
+
+
+class ImageVideoControlDataset(Dataset):
+    def __init__(
+        self,
+        ann_path, data_root=None,
+        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+        image_sample_size=512,
+        video_repeat=0,
+        text_drop_ratio=-1,
+        enable_bucket=False,
+        video_length_drop_start=0.1,
+        video_length_drop_end=0.9,
+        enable_inpaint=False,
+    ):
+        # Loading annotations from files
+        print(f"loading annotations from {ann_path} ...")
+        if ann_path.endswith('.csv'):
+            with open(ann_path, 'r') as csvfile:
+                dataset = list(csv.DictReader(csvfile))
+        elif ann_path.endswith('.json'):
+            dataset = json.load(open(ann_path))
+
+        self.data_root = data_root
+
+        # It's used to balance num of images and videos.
+        self.dataset = []
+        for data in dataset:
+            if data.get('type', 'image') != 'video':
+                self.dataset.append(data)
+        if video_repeat > 0:
+            for _ in range(video_repeat):
+                for data in dataset:
+                    if data.get('type', 'image') == 'video':
+                        self.dataset.append(data)
+        del dataset
+
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+        # TODO: enable bucket training
+        self.enable_bucket = enable_bucket
+        self.text_drop_ratio = text_drop_ratio
+        self.enable_inpaint = enable_inpaint
+
+        self.video_length_drop_start = video_length_drop_start
+        self.video_length_drop_end = video_length_drop_end
+
+        # Video params
+        self.video_sample_stride = video_sample_stride
+        self.video_sample_n_frames = video_sample_n_frames
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_transforms = transforms.Compose(
+            [
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+        # Image params
+        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+        self.image_transforms = transforms.Compose([
+            transforms.Resize(min(self.image_sample_size)),
+            transforms.CenterCrop(self.image_sample_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+        ])
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
+    def get_batch(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        video_id, text = data_info['file_path'], data_info['text']
+
+        if data_info.get('type', 'image')=='video':
+            if self.data_root is None:
+                video_dir = video_id
+            else:
+                video_dir = os.path.join(self.data_root, video_id)
+
+            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+                min_sample_n_frames = min(
+                    self.video_sample_n_frames,
+                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+                )
+                if min_sample_n_frames == 0:
+                    raise ValueError(f"No Frames in video.")
+
+                video_length = int(self.video_length_drop_end * len(video_reader))
+                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+                try:
+                    sample_args = (video_reader, batch_index)
+                    pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(pixel_values)):
+                        frame = pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    del video_reader
+                else:
+                    pixel_values = pixel_values
+
+                if not self.enable_bucket:
+                    pixel_values = self.video_transforms(pixel_values)
+
+                # Random use no text generation
+                if random.random() < self.text_drop_ratio:
+                    text = ''
+
+            control_video_id = data_info['control_file_path']
+
+            if self.data_root is None:
+                control_video_id = control_video_id
+            else:
+                control_video_id = os.path.join(self.data_root, control_video_id)
+
+            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+                try:
+                    sample_args = (control_video_reader, batch_index)
+                    control_pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(control_pixel_values)):
+                        frame = control_pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    control_pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+                    control_pixel_values = control_pixel_values / 255.
+                    del control_video_reader
+                else:
+                    control_pixel_values = control_pixel_values
+
+                if not self.enable_bucket:
+                    control_pixel_values = self.video_transforms(control_pixel_values)
+            return pixel_values, control_pixel_values, text, "video"
+        else:
+            image_path, text = data_info['file_path'], data_info['text']
+            if self.data_root is not None:
+                image_path = os.path.join(self.data_root, image_path)
+            image = Image.open(image_path).convert('RGB')
+            if not self.enable_bucket:
+                image = self.image_transforms(image).unsqueeze(0)
+            else:
+                image = np.expand_dims(np.array(image), 0)
+
+            if random.random() < self.text_drop_ratio:
+                text = ''
+
+            control_image_id = data_info['control_file_path']
+
+            if self.data_root is None:
+                control_image_id = control_image_id
+            else:
+                control_image_id = os.path.join(self.data_root, control_image_id)
+
+            control_image = Image.open(control_image_id).convert('RGB')
+            if not self.enable_bucket:
+                control_image = self.image_transforms(control_image).unsqueeze(0)
+            else:
+                control_image = np.expand_dims(np.array(control_image), 0)
+            return image, control_image, text, 'image'
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        data_type = data_info.get('type', 'image')
+        while True:
+            sample = {}
+            try:
+                data_info_local = self.dataset[idx % len(self.dataset)]
+                data_type_local = data_info_local.get('type', 'image')
+                if data_type_local != data_type:
+                    raise ValueError("data_type_local != data_type")
+
+                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+                sample["pixel_values"] = pixel_values
+                sample["control_pixel_values"] = control_pixel_values
+                sample["text"] = name
+                sample["data_type"] = data_type
+                sample["idx"] = idx
+
+                if len(sample) > 0:
+                    break
+            except Exception as e:
+                print(e, self.dataset[idx % len(self.dataset)])
+                idx = random.randint(0, self.length-1)
+
+        if self.enable_inpaint and not self.enable_bucket:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample["mask_pixel_values"] = mask_pixel_values
+            sample["mask"] = mask
+
+            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+            sample["clip_pixel_values"] = clip_pixel_values
+
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
         return sample
 
 if __name__ == "__main__":
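A quick smoke test of the two helpers added above, assuming the package is importable from the repository root; the shapes in the comments follow directly from the code in the diff.

```python
import numpy as np

from easyanimate.data.dataset_image_video import get_random_mask, resize_frame

# resize_frame scales the short side of an H x W x C frame to the target size.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
print(resize_frame(frame, 512).shape)   # (512, 910, 3)

# get_random_mask draws one of the ten mask patterns for an (f, c, h, w) clip.
mask = get_random_mask((16, 3, 256, 256))
print(mask.shape, mask.dtype)           # torch.Size([16, 1, 256, 256]) torch.uint8
```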
easyanimate/models/__init__.py
ADDED
@@ -0,0 +1,16 @@
+from .autoencoder_magvit import (AutoencoderKLCogVideoX, AutoencoderKLMagvit, AutoencoderKL)
+from .transformer3d import (EasyAnimateTransformer3DModel,
+                            HunyuanTransformer3DModel,
+                            Transformer3DModel)
+
+
+name_to_transformer3d = {
+    "Transformer3DModel": Transformer3DModel,
+    "HunyuanTransformer3DModel": HunyuanTransformer3DModel,
+    "EasyAnimateTransformer3DModel": EasyAnimateTransformer3DModel,
+}
+name_to_autoencoder_magvit = {
+    "AutoencoderKL": AutoencoderKL,
+    "AutoencoderKLMagvit": AutoencoderKLMagvit,
+    "AutoencoderKLCogVideoX": AutoencoderKLCogVideoX,
+}
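A minimal sketch of how these lookup tables can be combined with the `transformer_type` and `vae_type` keys in the config files added above; loading the YAML with `yaml.safe_load` is an assumption for illustration, and instantiation of the resolved classes is not shown in this diff.

```python
import yaml

from easyanimate.models import name_to_autoencoder_magvit, name_to_transformer3d

# Resolve the model classes named in the new-style config files.
with open("config/easyanimate_video_v5_magvit_multi_text_encoder.yaml") as f:
    config = yaml.safe_load(f)

transformer_cls = name_to_transformer3d[
    config["transformer_additional_kwargs"]["transformer_type"]
]  # -> EasyAnimateTransformer3DModel
vae_cls = name_to_autoencoder_magvit[
    config["vae_kwargs"]["vae_type"]
]  # -> AutoencoderKLMagvit

# How the classes are then instantiated (from_pretrained, extra kwargs, etc.)
# is handled in easyanimate/ui/ui.py and the pipelines, not reproduced here.
```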
easyanimate/models/attention.py
CHANGED
@@ -11,34 +11,38 @@
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
-
from typing import Any, Dict, Optional
|
15 |
|
16 |
import diffusers
|
17 |
import pkg_resources
|
18 |
import torch
|
19 |
import torch.nn.functional as F
|
20 |
import torch.nn.init as init
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
from diffusers.models.
|
32 |
-
from diffusers.models.
|
33 |
-
|
34 |
-
from diffusers.utils import USE_PEFT_BACKEND
|
35 |
from diffusers.utils.import_utils import is_xformers_available
|
36 |
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
37 |
from einops import rearrange, repeat
|
38 |
from torch import nn
|
39 |
|
40 |
from .motion_module import PositionalEncoding, get_motion_module
|
41 |
-
from .norm import FP32LayerNorm
|
|
|
|
|
|
|
|
|
42 |
|
43 |
if is_xformers_available():
|
44 |
import xformers
|
@@ -53,7 +57,6 @@ def zero_module(module):
|
|
53 |
p.detach().zero_()
|
54 |
return module
|
55 |
|
56 |
-
|
57 |
@maybe_allow_in_graph
|
58 |
class GatedSelfAttentionDense(nn.Module):
|
59 |
r"""
|
@@ -95,267 +98,33 @@ class GatedSelfAttentionDense(nn.Module):
|
|
95 |
|
96 |
return x
|
97 |
|
98 |
-
|
99 |
-
class KVCompressionCrossAttention(nn.Module):
|
100 |
-
r"""
|
101 |
-
A cross attention layer.
|
102 |
-
|
103 |
-
Parameters:
|
104 |
-
query_dim (`int`): The number of channels in the query.
|
105 |
-
cross_attention_dim (`int`, *optional*):
|
106 |
-
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
107 |
-
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
108 |
-
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
109 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
110 |
-
bias (`bool`, *optional*, defaults to False):
|
111 |
-
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
112 |
-
"""
|
113 |
-
|
114 |
def __init__(
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
):
|
127 |
-
super().__init__()
|
128 |
-
inner_dim = dim_head * heads
|
129 |
-
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
130 |
-
self.upcast_attention = upcast_attention
|
131 |
-
self.upcast_softmax = upcast_softmax
|
132 |
-
|
133 |
-
self.scale = dim_head**-0.5
|
134 |
-
|
135 |
-
self.heads = heads
|
136 |
-
# for slice_size > 0 the attention score computation
|
137 |
-
# is split across the batch axis to save memory
|
138 |
-
# You can set slice_size with `set_attention_slice`
|
139 |
-
self.sliceable_head_dim = heads
|
140 |
-
self._slice_size = None
|
141 |
-
self._use_memory_efficient_attention_xformers = True
|
142 |
-
self.added_kv_proj_dim = added_kv_proj_dim
|
143 |
-
|
144 |
-
if norm_num_groups is not None:
|
145 |
-
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
146 |
-
else:
|
147 |
-
self.group_norm = None
|
148 |
-
|
149 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
150 |
-
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
151 |
-
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
152 |
-
|
153 |
-
if self.added_kv_proj_dim is not None:
|
154 |
-
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
155 |
-
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
156 |
-
|
157 |
-
self.kv_compression = nn.Conv2d(
|
158 |
-
query_dim,
|
159 |
-
query_dim,
|
160 |
-
groups=query_dim,
|
161 |
-
kernel_size=2,
|
162 |
-
stride=2,
|
163 |
bias=True
|
164 |
)
|
165 |
-
self.
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
self.to_out.append(nn.Dropout(dropout))
|
173 |
-
|
174 |
-
def reshape_heads_to_batch_dim(self, tensor):
|
175 |
-
batch_size, seq_len, dim = tensor.shape
|
176 |
-
head_size = self.heads
|
177 |
-
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
178 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
179 |
-
return tensor
|
180 |
-
|
181 |
-
def reshape_batch_dim_to_heads(self, tensor):
|
182 |
-
batch_size, seq_len, dim = tensor.shape
|
183 |
-
head_size = self.heads
|
184 |
-
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
185 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
186 |
-
return tensor
|
187 |
-
|
188 |
-
def set_attention_slice(self, slice_size):
|
189 |
-
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
190 |
-
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
191 |
-
|
192 |
-
self._slice_size = slice_size
|
193 |
-
|
194 |
-
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
|
195 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
196 |
-
|
197 |
-
encoder_hidden_states = encoder_hidden_states
|
198 |
-
|
199 |
-
if self.group_norm is not None:
|
200 |
-
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
201 |
-
|
202 |
-
query = self.to_q(hidden_states)
|
203 |
-
dim = query.shape[-1]
|
204 |
-
query = self.reshape_heads_to_batch_dim(query)
|
205 |
-
|
206 |
-
if self.added_kv_proj_dim is not None:
|
207 |
-
key = self.to_k(hidden_states)
|
208 |
-
value = self.to_v(hidden_states)
|
209 |
-
|
210 |
-
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
211 |
-
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
212 |
-
|
213 |
-
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
214 |
-
key = self.kv_compression(key)
|
215 |
-
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
|
216 |
-
key = self.kv_compression_norm(key)
|
217 |
-
key = key.to(query.dtype)
|
218 |
-
|
219 |
-
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
220 |
-
value = self.kv_compression(value)
|
221 |
-
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
|
222 |
-
value = self.kv_compression_norm(value)
|
223 |
-
value = value.to(query.dtype)
|
224 |
-
|
225 |
-
key = self.reshape_heads_to_batch_dim(key)
|
226 |
-
value = self.reshape_heads_to_batch_dim(value)
|
227 |
-
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
228 |
-
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
229 |
-
|
230 |
-
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
231 |
-
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
232 |
-
else:
|
233 |
-
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
234 |
-
key = self.to_k(encoder_hidden_states)
|
235 |
-
value = self.to_v(encoder_hidden_states)
|
236 |
-
|
237 |
-
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
238 |
-
key = self.kv_compression(key)
|
239 |
-
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
|
240 |
-
key = self.kv_compression_norm(key)
|
241 |
-
key = key.to(query.dtype)
|
242 |
-
|
243 |
-
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
244 |
-
value = self.kv_compression(value)
|
245 |
-
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
|
246 |
-
value = self.kv_compression_norm(value)
|
247 |
-
value = value.to(query.dtype)
|
248 |
-
|
249 |
-
key = self.reshape_heads_to_batch_dim(key)
|
250 |
-
value = self.reshape_heads_to_batch_dim(value)
|
251 |
-
|
252 |
-
if attention_mask is not None:
|
253 |
-
if attention_mask.shape[-1] != query.shape[1]:
|
254 |
-
target_length = query.shape[1]
|
255 |
-
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
256 |
-
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
257 |
-
|
258 |
-
# attention, what we cannot get enough of
|
259 |
-
if self._use_memory_efficient_attention_xformers:
|
260 |
-
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
261 |
-
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
262 |
-
hidden_states = hidden_states.to(query.dtype)
|
263 |
-
else:
|
264 |
-
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
265 |
-
hidden_states = self._attention(query, key, value, attention_mask)
|
266 |
-
else:
|
267 |
-
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
268 |
-
|
269 |
-
# linear proj
|
270 |
-
hidden_states = self.to_out[0](hidden_states)
|
271 |
-
|
272 |
-
# dropout
|
273 |
-
hidden_states = self.to_out[1](hidden_states)
|
274 |
-
return hidden_states
|
275 |
-
|
276 |
-
def _attention(self, query, key, value, attention_mask=None):
|
277 |
-
if self.upcast_attention:
|
278 |
-
query = query.float()
|
279 |
-
key = key.float()
|
280 |
-
|
281 |
-
attention_scores = torch.baddbmm(
|
282 |
-
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
283 |
-
query,
|
284 |
-
key.transpose(-1, -2),
|
285 |
-
beta=0,
|
286 |
-
alpha=self.scale,
|
287 |
-
)
|
288 |
-
|
289 |
-
if attention_mask is not None:
|
290 |
-
attention_scores = attention_scores + attention_mask
|
291 |
-
|
292 |
-
if self.upcast_softmax:
|
293 |
-
attention_scores = attention_scores.float()
|
294 |
-
|
295 |
-
attention_probs = attention_scores.softmax(dim=-1)
|
296 |
-
|
297 |
-
# cast back to the original dtype
|
298 |
-
attention_probs = attention_probs.to(value.dtype)
|
299 |
-
|
300 |
-
# compute attention output
|
301 |
-
hidden_states = torch.bmm(attention_probs, value)
|
302 |
-
|
303 |
-
# reshape hidden_states
|
304 |
-
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
305 |
-
return hidden_states
|
306 |
-
|
307 |
-
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
308 |
-
batch_size_attention = query.shape[0]
|
309 |
-
hidden_states = torch.zeros(
|
310 |
-
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
311 |
)
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
query_slice = query[start_idx:end_idx]
|
318 |
-
key_slice = key[start_idx:end_idx]
|
319 |
-
|
320 |
-
if self.upcast_attention:
|
321 |
-
query_slice = query_slice.float()
|
322 |
-
key_slice = key_slice.float()
|
323 |
-
|
324 |
-
attn_slice = torch.baddbmm(
|
325 |
-
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
326 |
-
query_slice,
|
327 |
-
key_slice.transpose(-1, -2),
|
328 |
-
beta=0,
|
329 |
-
alpha=self.scale,
|
330 |
-
)
|
331 |
-
|
332 |
-
if attention_mask is not None:
|
333 |
-
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
334 |
-
|
335 |
-
if self.upcast_softmax:
|
336 |
-
attn_slice = attn_slice.float()
|
337 |
-
|
338 |
-
attn_slice = attn_slice.softmax(dim=-1)
|
339 |
-
|
340 |
-
# cast back to the original dtype
|
341 |
-
attn_slice = attn_slice.to(value.dtype)
|
342 |
-
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
343 |
-
|
344 |
-
hidden_states[start_idx:end_idx] = attn_slice
|
345 |
-
|
346 |
-
# reshape hidden_states
|
347 |
-
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
348 |
-
return hidden_states
|
349 |
-
|
350 |
-
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
351 |
-
# TODO attention_mask
|
352 |
-
query = query.contiguous()
|
353 |
-
key = key.contiguous()
|
354 |
-
value = value.contiguous()
|
355 |
-
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
356 |
-
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
357 |
-
return hidden_states
  @maybe_allow_in_graph
  class TemporalTransformerBlock(nn.Module):

@@ -413,8 +182,6 @@ class TemporalTransformerBlock(nn.Module):
          attention_type: str = "default",
          positional_embeddings: Optional[str] = None,
          num_positional_embeddings: Optional[int] = None,
-         # kv compression
-         kvcompression: Optional[bool] = False,
          # motion module kwargs
          motion_module_type = "VanillaGrid",
          motion_module_kwargs = None,

@@ -454,40 +221,17 @@ class TemporalTransformerBlock(nn.Module):
          else:
              self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-         self. ... (old lines 457-467 not captured in this view)
-         else:
-             if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
-                 self.attn1 = Attention(
-                     query_dim=dim,
-                     heads=num_attention_heads,
-                     dim_head=attention_head_dim,
-                     dropout=dropout,
-                     bias=attention_bias,
-                     cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-                     upcast_attention=upcast_attention,
-                     qk_norm="layer_norm" if qk_norm else None,
-                     processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
-                 )
-             else:
-                 self.attn1 = Attention(
-                     query_dim=dim,
-                     heads=num_attention_heads,
-                     dim_head=attention_head_dim,
-                     dropout=dropout,
-                     bias=attention_bias,
-                     cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-                     upcast_attention=upcast_attention,
-                 )

          self.attn_temporal = get_motion_module(
              in_channels = dim,
@@ -505,28 +249,17 @@ class TemporalTransformerBlock(nn.Module):
              if self.use_ada_layer_norm
              else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
          )
-         ... (old lines 508-518 not captured in this view)
-             ) # is self-attn if encoder_hidden_states is none
-         else:
-             self.attn2 = Attention(
-                 query_dim=dim,
-                 cross_attention_dim=cross_attention_dim if not double_self_attention else None,
-                 heads=num_attention_heads,
-                 dim_head=attention_head_dim,
-                 dropout=dropout,
-                 bias=attention_bias,
-                 upcast_attention=upcast_attention,
-             ) # is self-attn if encoder_hidden_states is none
      else:
          self.norm2 = None
          self.attn2 = None
@@ -605,23 +338,12 @@ class TemporalTransformerBlock(nn.Module):
          gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

          norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
-         ... (old lines 608-613 not captured in this view)
-                 height=height,
-                 width=width,
-                 **cross_attention_kwargs,
-             )
-         else:
-             attn_output = self.attn1(
-                 norm_hidden_states,
-                 encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-                 attention_mask=attention_mask,
-                 **cross_attention_kwargs,
-             )
          attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
          if self.use_ada_layer_norm_zero:
              attn_output = gate_msa.unsqueeze(1) * attn_output

@@ -658,6 +380,9 @@ class TemporalTransformerBlock(nn.Module):
          if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
              norm_hidden_states = self.pos_embed(norm_hidden_states)

          attn_output = self.attn2(
              norm_hidden_states,
              encoder_hidden_states=encoder_hidden_states,
@@ -760,7 +485,7 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
          double_self_attention: bool = False,
          upcast_attention: bool = False,
          norm_elementwise_affine: bool = True,
-         norm_type: str = "layer_norm",
          norm_eps: float = 1e-5,
          final_dropout: bool = False,
          attention_type: str = "default",

@@ -802,28 +527,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
          else:
              self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-         ... (old lines 805-815 not captured in this view)
-             )
-         else:
-             self.attn1 = Attention(
-                 query_dim=dim,
-                 heads=num_attention_heads,
-                 dim_head=attention_head_dim,
-                 dropout=dropout,
-                 bias=attention_bias,
-                 cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-                 upcast_attention=upcast_attention,
-             )

          # 2. Cross-Attn
          if cross_attention_dim is not None or double_self_attention:
@@ -835,28 +549,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
              if self.use_ada_layer_norm
              else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
          )
-         ... (old lines 838-848 not captured in this view)
-             ) # is self-attn if encoder_hidden_states is none
-         else:
-             self.attn2 = Attention(
-                 query_dim=dim,
-                 cross_attention_dim=cross_attention_dim if not double_self_attention else None,
-                 heads=num_attention_heads,
-                 dim_head=attention_head_dim,
-                 dropout=dropout,
-                 bias=attention_bias,
-                 upcast_attention=upcast_attention,
-             ) # is self-attn if encoder_hidden_states is none
      else:
          self.norm2 = None
          self.attn2 = None
@@ -1017,340 +720,415 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
          hidden_states = hidden_states.squeeze(1)

          return hidden_states

- @maybe_allow_in_graph
- class ... (old lines 1022-1356: a transformer block class removed by this commit; its name is not captured in this view)
  [Only fragments of the removed block survive in this view: a parameter docstring covering dim,
  cross_attention_dim, only_cross_attention, double_self_attention, upcast_attention,
  norm_elementwise_affine and final_dropout; an __init__ taking dim, num_attention_heads, dropout,
  cross_attention_dim, activation_fn, num_embeds_ada_norm, attention_bias, only_cross_attention,
  double_self_attention, upcast_attention, norm_elementwise_affine, norm_eps, final_dropout, after_norm
  and a kvcompression flag, building AdaLayerNormZero / FP32LayerNorm norm1, SinusoidalPositionalEmbedding,
  attn1/attn2, an optional GatedSelfAttentionDense fuser, a PixArt-Alpha scale_shift_table, an optional
  norm4, and chunked feed-forward state with set_chunk_feed_forward; and a forward(hidden_states, ...,
  class_labels, num_frames=16, height=32, width=32) that prepared GLIGEN inputs, ran attn1 (with an extra
  num_frames/height/width call path when kvcompression was set), applied the GLIGEN fuser, selected norm2
  per ada_norm / ada_norm_zero / layer_norm / ada_norm_single, ran cross-attention, applied the scale/shift
  MLP modulation and the (optionally chunked) feed-forward, including the "has to be divisible by chunk
  size" error, and returned hidden_states. The full text of these removed lines is not recoverable from
  this view.]
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ from typing import Any, Dict, Optional, Tuple, Union

  import diffusers
  import pkg_resources
  import torch
  import torch.nn.functional as F
  import torch.nn.init as init
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.attention import Attention, FeedForward
+ from diffusers.models.attention_processor import (Attention,
+                                                    AttentionProcessor,
+                                                    AttnProcessor2_0,
+                                                    HunyuanAttnProcessor2_0)
+ from diffusers.models.embeddings import (SinusoidalPositionalEmbedding,
+                                           TimestepEmbedding, Timesteps,
+                                           get_3d_sincos_pos_embed)
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
+                                              CogVideoXLayerNormZero)
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
  from diffusers.utils.import_utils import is_xformers_available
  from diffusers.utils.torch_utils import maybe_allow_in_graph
  from einops import rearrange, repeat
  from torch import nn

  from .motion_module import PositionalEncoding, get_motion_module
+ from .norm import AdaLayerNormShift, FP32LayerNorm, EasyAnimateLayerNormZero
+ from .processor import (EasyAnimateAttnProcessor2_0,
+                          LazyKVCompressionProcessor2_0)

  if is_xformers_available():
      import xformers
          p.detach().zero_()
      return module

  @maybe_allow_in_graph
  class GatedSelfAttentionDense(nn.Module):
      r"""

      return x

+ class LazyKVCompressionAttention(Attention):
+     def __init__(
+         self,
+         sr_ratio=2, *args, **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.sr_ratio = sr_ratio
+         self.k_compression = nn.Conv2d(
+             kwargs["query_dim"],
+             kwargs["query_dim"],
+             groups=kwargs["query_dim"],
+             kernel_size=sr_ratio,
+             stride=sr_ratio,
+             bias=True
+         )
+         self.v_compression = nn.Conv2d(
+             kwargs["query_dim"],
+             kwargs["query_dim"],
+             groups=kwargs["query_dim"],
+             kernel_size=sr_ratio,
+             stride=sr_ratio,
+             bias=True
+         )
+         init.constant_(self.k_compression.weight, 1 / (sr_ratio * sr_ratio))
+         init.constant_(self.v_compression.weight, 1 / (sr_ratio * sr_ratio))
+         init.constant_(self.k_compression.bias, 0)
+         init.constant_(self.v_compression.bias, 0)
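LazyKVCompressionAttention adds two depthwise Conv2d layers whose weights start out as an sr_ratio x sr_ratio average pool, so keys and values can be spatially downsampled before attention; the actual downsampling is driven by the LazyKVCompressionProcessor2_0 imported from easyanimate/models/processor.py (added in this commit but not shown here). A minimal sketch of the token-count effect of such a compression layer; the sizes below are made up for illustration:

    import torch
    import torch.nn as nn

    # Hypothetical shapes: a 32x32 spatial grid of 1152-dim tokens, sr_ratio=2.
    dim, height, width, sr_ratio = 1152, 32, 32, 2
    compress = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, bias=True)
    nn.init.constant_(compress.weight, 1 / (sr_ratio * sr_ratio))  # starts as average pooling
    nn.init.constant_(compress.bias, 0)

    feat = torch.randn(1, dim, height, width)                   # (B, C, H, W) key/value feature map
    compressed = compress(feat)                                  # (B, C, H/2, W/2)
    tokens_before = height * width                               # 1024 key/value tokens
    tokens_after = compressed.shape[-2] * compressed.shape[-1]   # 256 tokens, a 4x reduction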
  @maybe_allow_in_graph
  class TemporalTransformerBlock(nn.Module):

          attention_type: str = "default",
          positional_embeddings: Optional[str] = None,
          num_positional_embeddings: Optional[int] = None,
          # motion module kwargs
          motion_module_type = "VanillaGrid",
          motion_module_kwargs = None,

          else:
              self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+             upcast_attention=upcast_attention,
+             qk_norm="layer_norm" if qk_norm else None,
+             processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+         )

          self.attn_temporal = get_motion_module(
              in_channels = dim,

              if self.use_ada_layer_norm
              else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
          )
+         self.attn2 = Attention(
+             query_dim=dim,
+             cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             upcast_attention=upcast_attention,
+             qk_norm="layer_norm" if qk_norm else None,
+             processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+         ) # is self-attn if encoder_hidden_states is none
      else:
          self.norm2 = None
          self.attn2 = None

          gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

          norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
+         attn_output = self.attn1(
+             norm_hidden_states,
+             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+             attention_mask=attention_mask,
+             **cross_attention_kwargs,
+         )
          attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
          if self.use_ada_layer_norm_zero:
              attn_output = gate_msa.unsqueeze(1) * attn_output

          if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
              norm_hidden_states = self.pos_embed(norm_hidden_states)

+         if norm_hidden_states.dtype != encoder_hidden_states.dtype or norm_hidden_states.dtype != encoder_attention_mask.dtype:
+             norm_hidden_states = norm_hidden_states.to(encoder_hidden_states.dtype)
+
          attn_output = self.attn2(
              norm_hidden_states,
              encoder_hidden_states=encoder_hidden_states,
          double_self_attention: bool = False,
          upcast_attention: bool = False,
          norm_elementwise_affine: bool = True,
+         norm_type: str = "layer_norm",
          norm_eps: float = 1e-5,
          final_dropout: bool = False,
          attention_type: str = "default",

          else:
              self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+             upcast_attention=upcast_attention,
+             qk_norm="layer_norm" if qk_norm else None,
+             processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+         )

          # 2. Cross-Attn
          if cross_attention_dim is not None or double_self_attention:

              if self.use_ada_layer_norm
              else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
          )
+         self.attn2 = Attention(
+             query_dim=dim,
+             cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             upcast_attention=upcast_attention,
+             qk_norm="layer_norm" if qk_norm else None,
+             processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+         ) # is self-attn if encoder_hidden_states is none
      else:
          self.norm2 = None
          self.attn2 = None

          hidden_states = hidden_states.squeeze(1)

          return hidden_states
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in, dim_out, norm_elementwise_affine):
+         super().__init__()
+         self.norm = FP32LayerNorm(dim_in, dim_in, norm_elementwise_affine)
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, x):
+         x, gate = self.proj(self.norm(x)).chunk(2, dim=-1)
+         return x * F.gelu(gate)
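This GEGLU variant first layer-norms the input, projects it to twice the width, and multiplies one half by GELU of the other; HunyuanDiTBlock below uses it as gate_clip so the CLIP-conditioned attention branch can be blended in softly. A standalone sketch of just the gating arithmetic, with toy shapes and without the FP32LayerNorm:

    import torch
    import torch.nn.functional as F

    dim = 8
    x = torch.randn(2, 16, dim)                 # (batch, tokens, dim)
    proj = torch.nn.Linear(dim, dim * 2)
    value, gate = proj(x).chunk(2, dim=-1)      # split the doubled projection in two halves
    out = value * F.gelu(gate)                  # a gate near zero suppresses the branch
    assert out.shape == x.shape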
  @maybe_allow_in_graph
+ class HunyuanDiTBlock(nn.Module):
      r"""
+     Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connections
+     and QK norm.

      Parameters:
+         dim (`int`):
+             The number of channels in the input and output.
+         num_attention_heads (`int`):
+             The number of heads to use for multi-head attention.
+         cross_attention_dim (`int`, *optional*):
+             The size of the encoder_hidden_states vector for cross attention.
+         dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability to use.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`):
+             Activation function to be used in the feed-forward block.
          norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
              Whether to use learnable elementwise affine parameters for normalization.
+         norm_eps (`float`, *optional*, defaults to 1e-6):
+             A small constant added to the denominator in normalization layers to prevent division by zero.
          final_dropout (`bool` *optional*, defaults to False):
              Whether to apply a final dropout after the last feed-forward layer.
+         ff_inner_dim (`int`, *optional*):
+             The size of the hidden layer in the feed-forward block. Defaults to `None`.
+         ff_bias (`bool`, *optional*, defaults to `True`):
+             Whether to use bias in the feed-forward block.
+         skip (`bool`, *optional*, defaults to `False`):
+             Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
+         qk_norm (`bool`, *optional*, defaults to `True`):
+             Whether to use normalization in QK calculation. Defaults to `True`.
      """
      def __init__(
          self,
          dim: int,
          num_attention_heads: int,
+         cross_attention_dim: int = 1024,
          dropout=0.0,
          activation_fn: str = "geglu",
          norm_elementwise_affine: bool = True,
+         norm_eps: float = 1e-6,
          final_dropout: bool = False,
+         ff_inner_dim: Optional[int] = None,
+         ff_bias: bool = True,
+         skip: bool = False,
+         qk_norm: bool = True,
+         time_position_encoding: bool = False,
+         after_norm: bool = False,
+         is_local_attention: bool = False,
+         local_attention_frames: int = 2,
+         enable_inpaint: bool = False,
+         kvcompression = False,
      ):
          super().__init__()

          # Define 3 blocks. Each block has its own normalization layer.
+         # NOTE: when new version comes, check norm2 and norm 3
          # 1. Self-Attn
+         self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+         self.t_embed = PositionalEncoding(dim, dropout=0., max_len=512) \
+             if time_position_encoding else nn.Identity()

+         self.is_local_attention = is_local_attention
+         self.local_attention_frames = local_attention_frames
          self.kvcompression = kvcompression
          if kvcompression:
+             self.attn1 = LazyKVCompressionAttention(
                  query_dim=dim,
+                 cross_attention_dim=None,
+                 dim_head=dim // num_attention_heads,
                  heads=num_attention_heads,
+                 qk_norm="layer_norm" if qk_norm else None,
+                 eps=1e-6,
+                 bias=True,
+                 processor=LazyKVCompressionProcessor2_0(),
              )
          else:
+             self.attn1 = Attention(
+                 query_dim=dim,
+                 cross_attention_dim=None,
+                 dim_head=dim // num_attention_heads,
+                 heads=num_attention_heads,
+                 qk_norm="layer_norm" if qk_norm else None,
+                 eps=1e-6,
+                 bias=True,
+                 processor=HunyuanAttnProcessor2_0(),
+             )

          # 2. Cross-Attn
+         self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+         if self.is_local_attention:
+             from mamba_ssm import Mamba2
+             self.mamba_norm_in = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+             self.in_linear = nn.Linear(dim, 1536)
+             self.mamba_norm_1 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
+             self.mamba_norm_2 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
+
+             self.mamba_block_1 = Mamba2(
+                 d_model=1536,
+                 d_state=64,
+                 d_conv=4,
+                 expand=2,
              )
+             self.mamba_block_2 = Mamba2(
+                 d_model=1536,
+                 d_state=64,
+                 d_conv=4,
+                 expand=2,
+             )
+             self.mamba_norm_after_mamba_block = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
+
+             self.out_linear = nn.Linear(1536, dim)
+             self.out_linear = zero_module(self.out_linear)
+             self.mamba_norm_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+         self.attn2 = Attention(
+             query_dim=dim,
+             cross_attention_dim=cross_attention_dim,
+             dim_head=dim // num_attention_heads,
+             heads=num_attention_heads,
+             qk_norm="layer_norm" if qk_norm else None,
+             eps=1e-6,
+             bias=True,
+             processor=HunyuanAttnProcessor2_0(),
+         )

+         if enable_inpaint:
+             self.norm_clip = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+             self.attn_clip = Attention(
+                 query_dim=dim,
+                 cross_attention_dim=cross_attention_dim,
+                 dim_head=dim // num_attention_heads,
+                 heads=num_attention_heads,
+                 qk_norm="layer_norm" if qk_norm else None,
+                 eps=1e-6,
+                 bias=True,
+                 processor=HunyuanAttnProcessor2_0(),
+             )
+             self.gate_clip = GEGLU(dim, dim, norm_elementwise_affine)
+             self.norm_clip_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+         else:
+             self.attn_clip = None
+             self.norm_clip = None
+             self.gate_clip = None
+             self.norm_clip_out = None
+
          # 3. Feed-forward
+         self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+         self.ff = FeedForward(
+             dim,
+             dropout=dropout,                  ### 0.0
+             activation_fn=activation_fn,      ### approx GeLU
+             final_dropout=final_dropout,      ### 0.0
+             inner_dim=ff_inner_dim,           ### int(dim * mlp_ratio)
+             bias=ff_bias,
+         )

+         # 4. Skip Connection
+         if skip:
+             self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
+             self.skip_linear = nn.Linear(2 * dim, dim)
+         else:
+             self.skip_linear = None

          if after_norm:
              self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
          else:
              self.norm4 = None

          # let chunk size default to None
          self._chunk_size = None
          self._chunk_dim = 0

+     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
          # Sets chunk feed-forward
          self._chunk_size = chunk_size
          self._chunk_dim = dim
      def forward(
          self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         temb: Optional[torch.Tensor] = None,
+         image_rotary_emb=None,
+         skip=None,
+         num_frames: int = 1,
          height: int = 32,
          width: int = 32,
+         clip_encoder_hidden_states: Optional[torch.Tensor] = None,
+         disable_image_rotary_emb_in_attn1=False,
+     ) -> torch.Tensor:
          # Notice that normalization is always applied before the real computation in the following blocks.
+         # 0. Long Skip Connection
+         if self.skip_linear is not None:
+             cat = torch.cat([hidden_states, skip], dim=-1)
+             cat = self.skip_norm(cat)
+             hidden_states = self.skip_linear(cat)
+
+         if image_rotary_emb is not None:
+             image_rotary_emb = (torch.cat([image_rotary_emb[0] for i in range(num_frames)], dim=0), torch.cat([image_rotary_emb[1] for i in range(num_frames)], dim=0))
+
+         if num_frames != 1:
+             # add time embedding
+             hidden_states = rearrange(hidden_states, "b (f d) c -> (b d) f c", f=num_frames)
+             if self.t_embed is not None:
+                 hidden_states = self.t_embed(hidden_states)
+             hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", d=height * width)
+
+         # 1. Self-Attention
+         norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
+         if num_frames > 2 and self.is_local_attention:
+             if image_rotary_emb is not None:
+                 attn1_image_rotary_emb = (image_rotary_emb[0][:int(height * width * 2)], image_rotary_emb[1][:int(height * width * 2)])
+             else:
+                 attn1_image_rotary_emb = image_rotary_emb
+             norm_hidden_states_1 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d=height * width)
+             norm_hidden_states_1 = rearrange(norm_hidden_states_1, "b (f p) d c -> (b f) (p d) c", p = 2)
+
              attn_output = self.attn1(
+                 norm_hidden_states_1,
+                 image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
              )
+             attn_output = rearrange(attn_output, "(b f) (p d) c -> b (f p) d c", p = 2, f = num_frames // 2)
+
+             norm_hidden_states_2 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d = height * width)[:, 1:-1]
+             local_attention_frames_num = norm_hidden_states_2.size()[1] // 2
+             norm_hidden_states_2 = rearrange(norm_hidden_states_2, "b (f p) d c -> (b f) (p d) c", p = 2)
+             attn_output_2 = self.attn1(
+                 norm_hidden_states_2,
+                 image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
              )
+             attn_output_2 = rearrange(attn_output_2, "(b f) (p d) c -> b (f p) d c", p = 2, f = local_attention_frames_num)
+             attn_output[:, 1:-1] = (attn_output[:, 1:-1] + attn_output_2) / 2

+             attn_output = rearrange(attn_output, "b f d c -> b (f d) c")
+         else:
+             if self.kvcompression:
+                 norm_hidden_states = rearrange(norm_hidden_states, "b (f h w) c -> b c f h w", f = num_frames, h = height, w = width)
+                 attn_output = self.attn1(
+                     norm_hidden_states,
+                     image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
+                 )
              else:
+                 attn_output = self.attn1(
+                     norm_hidden_states,
+                     image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None,
+                 )
+         hidden_states = hidden_states + attn_output
+
+         if num_frames > 2 and self.is_local_attention:
+             hidden_states_in = self.in_linear(self.mamba_norm_in(hidden_states))
+             hidden_states = hidden_states + self.mamba_norm_out(
+                 self.out_linear(
+                     self.mamba_norm_after_mamba_block(
+                         self.mamba_block_1(
+                             self.mamba_norm_1(hidden_states_in)
+                         ) +
+                         self.mamba_block_2(
+                             self.mamba_norm_2(hidden_states_in.flip(1))
+                         ).flip(1)
+                     )
+                 )
+             )
+
+         # 2. Cross-Attention
+         hidden_states = hidden_states + self.attn2(
+             self.norm2(hidden_states),
+             encoder_hidden_states=encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+         )

+         if self.attn_clip is not None:
+             hidden_states = hidden_states + self.norm_clip_out(
+                 self.gate_clip(
+                     self.attn_clip(
+                         self.norm_clip(hidden_states),
+                         encoder_hidden_states=clip_encoder_hidden_states,
+                         image_rotary_emb=image_rotary_emb,
+                     )
+                 )
              )

+         # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
+         mlp_inputs = self.norm3(hidden_states)
+         if self.norm4 is not None:
+             hidden_states = hidden_states + self.norm4(self.ff(mlp_inputs))
+         else:
+             hidden_states = hidden_states + self.ff(mlp_inputs)

+         return hidden_states
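In the local-attention branch of the forward above, frames are grouped into pairs with einops ("b (f p) d c -> (b f) (p d) c", p=2), attention runs inside each two-frame window, and a second pass repeats this on the sequence with the first and last frame dropped, so interior frames are covered by two windows and averaged. A small sketch of only the grouping arithmetic, with toy sizes chosen for illustration:

    import torch
    from einops import rearrange

    b, f, d, c = 1, 6, 4, 8            # batch, frames, tokens per frame, channels (toy sizes)
    x = torch.randn(b, f, d, c)

    # Pass 1: pair frames (0,1), (2,3), (4,5); each attention call sees 2*d tokens.
    pairs = rearrange(x, "b (f p) d c -> (b f) (p d) c", p=2)
    assert pairs.shape == (b * f // 2, 2 * d, c)

    # Pass 2: drop the first and last frame, pairing (1,2), (3,4) instead,
    # so every interior frame falls into two windows whose outputs can be averaged.
    shifted = rearrange(x[:, 1:-1], "b (f p) d c -> (b f) (p d) c", p=2)
    assert shifted.shape == (b * (f - 2) // 2, 2 * d, c)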
+ @maybe_allow_in_graph
+ class EasyAnimateDiTBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         time_embed_dim: int,
+         dropout: float = 0.0,
+         activation_fn: str = "gelu-approximate",
+         norm_elementwise_affine: bool = True,
+         norm_eps: float = 1e-6,
+         final_dropout: bool = True,
+         ff_inner_dim: Optional[int] = None,
+         ff_bias: bool = True,
+         qk_norm: bool = True,
+         after_norm: bool = False,
+         norm_type: str="fp32_layer_norm"
+     ):
+         super().__init__()

+         # Attention Part
+         self.norm1 = EasyAnimateLayerNormZero(
+             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+         )

+         self.attn1 = Attention(
+             query_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             qk_norm="layer_norm" if qk_norm else None,
+             eps=1e-6,
+             bias=True,
+             processor=EasyAnimateAttnProcessor2_0(),
+         )
+         self.attn2 = Attention(
+             query_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             qk_norm="layer_norm" if qk_norm else None,
+             eps=1e-6,
+             bias=True,
+             processor=EasyAnimateAttnProcessor2_0(),
+         )

+         # FFN Part
+         self.norm2 = EasyAnimateLayerNormZero(
+             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+         )
+         self.ff = FeedForward(
+             dim,
+             dropout=dropout,
+             activation_fn=activation_fn,
+             final_dropout=final_dropout,
+             inner_dim=ff_inner_dim,
+             bias=ff_bias,
+         )
+         self.txt_ff = FeedForward(
+             dim,
+             dropout=dropout,
+             activation_fn=activation_fn,
+             final_dropout=final_dropout,
+             inner_dim=ff_inner_dim,
+             bias=ff_bias,
+         )
+         if after_norm:
+             self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+         else:
+             self.norm3 = None

+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> torch.Tensor:
+         # Norm
+         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+             hidden_states, encoder_hidden_states, temb
+         )

+         # Attn
+         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attn2=self.attn2,
+         )
+         hidden_states = hidden_states + gate_msa * attn_hidden_states
+         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

+         # Norm
+         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+             hidden_states, encoder_hidden_states, temb
+         )

+         # FFN
+         if self.norm3 is not None:
+             norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
+             norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+         else:
+             norm_hidden_states = self.ff(norm_hidden_states)
+             norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+         hidden_states = hidden_states + gate_ff * norm_hidden_states
+         encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
+         return hidden_states, encoder_hidden_states
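EasyAnimateDiTBlock keeps video tokens and text tokens as two streams: EasyAnimateLayerNormZero returns a normalized copy of each stream plus per-stream gates derived from the timestep embedding, attention updates both streams at once, and each stream has its own FeedForward (ff vs txt_ff). A minimal sketch of just the gated-residual update used twice in the forward above; the gate values and branch outputs are random stand-ins, not the real modules:

    import torch

    dim = 16
    vid = torch.randn(1, 64, dim)                 # video-token stream
    txt = torch.randn(1, 12, dim)                 # text-token stream

    # Stand-ins for what norm1/norm2 gates and attn1/ff branches produce in the real block.
    gate_vid, gate_txt = torch.rand(1, 1, dim), torch.rand(1, 1, dim)
    branch_vid, branch_txt = torch.randn_like(vid), torch.randn_like(txt)

    # The block applies this same pattern after attention and again after the feed-forward:
    vid = vid + gate_vid * branch_vid
    txt = txt + gate_txt * branch_txt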
easyanimate/models/autoencoder_magvit.py
CHANGED
@@ -15,8 +15,14 @@ from typing import Dict, Optional, Tuple, Union

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
  from diffusers.configuration_utils import ConfigMixin, register_to_config

  try:
      from diffusers.loaders import FromOriginalVAEMixin
@@ -32,10 +38,16 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput
  from diffusers.models.modeling_utils import ModelMixin
  from diffusers.utils.accelerate_utils import apply_forward_hook
  from torch import nn

  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder


  def str_eval(item):
      if type(item) == str:
@@ -97,10 +109,19 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
          latent_channels: int = 4,
          norm_num_groups: int = 32,
          scaling_factor: float = 0.1825,
          slice_compression_vae=False,
          use_tiling=False,
          mini_batch_encoder=9,
          mini_batch_decoder=3,
      ):
          super().__init__()
          down_block_types = str_eval(down_block_types)
@@ -121,8 +142,12 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
              act_fn=act_fn,
              num_attention_heads=num_attention_heads,
              double_z=True,
              slice_compression_vae=slice_compression_vae,
              mini_batch_encoder=mini_batch_encoder,
          )

          self.decoder = omnigen_Mag_Decoder(
@@ -140,20 +165,30 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
              norm_num_groups=norm_num_groups,
              act_fn=act_fn,
              num_attention_heads=num_attention_heads,
              slice_compression_vae=slice_compression_vae,
              mini_batch_decoder=mini_batch_decoder,
          )

          self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
          self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

          self.slice_compression_vae = slice_compression_vae
          self.mini_batch_encoder = mini_batch_encoder
          self.mini_batch_decoder = mini_batch_decoder
          self.use_slicing = False
          self.use_tiling = use_tiling
-         self. ... (two removed attribute assignments, old lines 155-156, not captured in this view)
          self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
          self.scaling_factor = scaling_factor
@@ -253,8 +288,16 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
              The latent representations of the encoded images. If `return_dict` is True, a
              [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
          """
          if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-             ... (removed line 257 not captured in this view)

          if self.use_slicing and x.shape[0] > 1:
              encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
@@ -271,8 +314,15 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
          return AutoencoderKLOutput(latent_dist=posterior)

      def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
          if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
              return self.tiled_decode(z, return_dict=return_dict)
          z = self.post_quant_conv(z)
          dec = self.decoder(z)
@@ -408,6 +458,34 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
              result_rows.append(torch.cat(result_row, dim=4))

          dec = torch.cat(result_rows, dim=3)
          if not return_dict:
              return (dec,)
@@ -507,3 +585,441 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
          print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
          print(m, u)
          return model
  import torch
  import torch.nn as nn
  from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
+                                                 DiagonalGaussianDistribution)
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import logging
+ from diffusers.utils.accelerate_utils import apply_forward_hook

  try:
      from diffusers.loaders import FromOriginalVAEMixin

  from diffusers.models.modeling_utils import ModelMixin
  from diffusers.utils.accelerate_utils import apply_forward_hook
  from torch import nn
+ from diffusers import AutoencoderKL

+ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d,
+                                                  CogVideoXDecoder3D,
+                                                  CogVideoXEncoder3D,
+                                                  CogVideoXSafeConv3d)
  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder

+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  def str_eval(item):
      if type(item) == str:
          latent_channels: int = 4,
          norm_num_groups: int = 32,
          scaling_factor: float = 0.1825,
+         slice_mag_vae=True,
          slice_compression_vae=False,
+         cache_compression_vae=False,
+         cache_mag_vae=False,
          use_tiling=False,
+         use_tiling_encoder=False,
+         use_tiling_decoder=False,
          mini_batch_encoder=9,
          mini_batch_decoder=3,
+         upcast_vae=False,
+         spatial_group_norm=False,
+         tile_sample_min_size=384,
+         tile_overlap_factor=0.25,
      ):
          super().__init__()
          down_block_types = str_eval(down_block_types)

              act_fn=act_fn,
              num_attention_heads=num_attention_heads,
              double_z=True,
+             slice_mag_vae=slice_mag_vae,
              slice_compression_vae=slice_compression_vae,
+             cache_compression_vae=cache_compression_vae,
+             cache_mag_vae=cache_mag_vae,
              mini_batch_encoder=mini_batch_encoder,
+             spatial_group_norm=spatial_group_norm,
          )

          self.decoder = omnigen_Mag_Decoder(

              norm_num_groups=norm_num_groups,
              act_fn=act_fn,
              num_attention_heads=num_attention_heads,
+             slice_mag_vae=slice_mag_vae,
              slice_compression_vae=slice_compression_vae,
+             cache_compression_vae=cache_compression_vae,
+             cache_mag_vae=cache_mag_vae,
              mini_batch_decoder=mini_batch_decoder,
+             spatial_group_norm=spatial_group_norm,
          )

          self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
          self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

+         self.slice_mag_vae = slice_mag_vae
          self.slice_compression_vae = slice_compression_vae
+         self.cache_compression_vae = cache_compression_vae
+         self.cache_mag_vae = cache_mag_vae
          self.mini_batch_encoder = mini_batch_encoder
          self.mini_batch_decoder = mini_batch_decoder
          self.use_slicing = False
          self.use_tiling = use_tiling
+         self.use_tiling_encoder = use_tiling_encoder
+         self.use_tiling_decoder = use_tiling_decoder
+         self.upcast_vae = upcast_vae
+         self.tile_sample_min_size = tile_sample_min_size
+         self.tile_overlap_factor = tile_overlap_factor
          self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
          self.scaling_factor = scaling_factor

              The latent representations of the encoded images. If `return_dict` is True, a
              [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
          """
+         if self.upcast_vae:
+             x = x.float()
+             self.encoder = self.encoder.float()
+             self.quant_conv = self.quant_conv.float()
          if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+             x = self.tiled_encode(x, return_dict=return_dict)
+             return x
+         if self.use_tiling_encoder and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+             x = self.tiled_encode(x, return_dict=return_dict)
+             return x

          if self.use_slicing and x.shape[0] > 1:
              encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]

          return AutoencoderKLOutput(latent_dist=posterior)

      def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+         if self.upcast_vae:
+             z = z.float()
+             self.decoder = self.decoder.float()
+             self.post_quant_conv = self.post_quant_conv.float()
          if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
              return self.tiled_decode(z, return_dict=return_dict)
+         if self.use_tiling_decoder and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+             return self.tiled_decode(z, return_dict=return_dict)

          z = self.post_quant_conv(z)
          dec = self.decoder(z)
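With the new flags, encode() optionally upcasts the encoder to float32 and falls back to tiled encoding whenever the input is larger than tile_sample_min_size (384 by default in the new signature). A toy sketch of just that size check; the helper name is made up for illustration:

    # Toy decision logic mirroring the encode() checks above (not the actual method).
    tile_sample_min_size = 384

    def needs_tiling(height: int, width: int) -> bool:
        return height > tile_sample_min_size or width > tile_sample_min_size

    print(needs_tiling(256, 256))   # False -> encode the frame in one pass
    print(needs_tiling(1280, 720))  # True  -> fall back to tiled_encode with overlapping tiles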
          result_rows.append(torch.cat(result_row, dim=4))

      dec = torch.cat(result_rows, dim=3)
+
+     # Handle the lower right corner tile separately
+     lower_right_original = z[
+         :,
+         :,
+         :,
+         -self.tile_latent_min_size:,
+         -self.tile_latent_min_size:
+     ]
+     quantized_lower_right = self.decoder(self.post_quant_conv(lower_right_original))
+
+     # Combine
+     H, W = quantized_lower_right.size(-2), quantized_lower_right.size(-1)
+     x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)
+     y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)
+     weights = torch.min(x_weights, y_weights)
+
+     if len(dec.size()) == 4:
+         weights = weights.unsqueeze(0).unsqueeze(0)
+     elif len(dec.size()) == 5:
+         weights = weights.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+     weights = weights.to(dec.device)
+     quantized_area = dec[:, :, :, -H:, -W:]
+     combined = weights * quantized_lower_right + (1 - weights) * quantized_area
+
+     dec[:, :, :, -H:, -W:] = combined
+
      if not return_dict:
          return (dec,)

          print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
          print(m, u)
          return model
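The extra lower-right pass above re-decodes the last tile_latent_min_size x tile_latent_min_size latent corner and blends it over the already-assembled output with a ramp built from the element-wise minimum of a horizontal and a vertical linspace, so the patch fades in toward the bottom-right instead of leaving a seam. A small numeric illustration of that weight map:

    import torch

    H = W = 4                                                         # tiny corner patch for illustration
    x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)     # grows left -> right
    y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)     # grows top -> bottom
    weights = torch.min(x_weights, y_weights)
    print(weights)
    # tensor([[0.0000, 0.0000, 0.0000, 0.0000],
    #         [0.0000, 0.3333, 0.3333, 0.3333],
    #         [0.0000, 0.3333, 0.6667, 0.6667],
    #         [0.0000, 0.3333, 0.6667, 1.0000]])
    # blended = weights * corner_decode + (1 - weights) * existing_decode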
+
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+     r"""
+     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+     [CogVideoX](https://github.com/THUDM/CogVideo).
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+     for all models (such as downloading or saving).
+
+     Parameters:
+         in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+         out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+             Tuple of downsample block types.
+         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+             Tuple of upsample block types.
+         block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+             Tuple of block output channels.
+         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+         sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+         scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+             The component-wise standard deviation of the trained latent space computed using the first batch of the
+             training set. This is used to scale the latent space to have unit variance when training the diffusion
+             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+         force_upcast (`bool`, *optional*, defaults to `True`):
+             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+             can be fine-tuned / trained to a lower range without losing too much precision, in which case
+             `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["CogVideoXResnetBlock3D"]
+
+     @register_to_config
+     def __init__(
+         self,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         down_block_types: Tuple[str] = (
+             "CogVideoXDownBlock3D",
+             "CogVideoXDownBlock3D",
+             "CogVideoXDownBlock3D",
+             "CogVideoXDownBlock3D",
+         ),
+         up_block_types: Tuple[str] = (
+             "CogVideoXUpBlock3D",
+             "CogVideoXUpBlock3D",
+             "CogVideoXUpBlock3D",
+             "CogVideoXUpBlock3D",
+         ),
+         block_out_channels: Tuple[int] = (128, 256, 256, 512),
+         latent_channels: int = 16,
+         layers_per_block: int = 3,
+         act_fn: str = "silu",
+         norm_eps: float = 1e-6,
+         norm_num_groups: int = 32,
+         temporal_compression_ratio: float = 4,
+         sample_height: int = 480,
+         sample_width: int = 720,
+         scaling_factor: float = 1.15258426,
+         shift_factor: Optional[float] = None,
+         latents_mean: Optional[Tuple[float]] = None,
+         latents_std: Optional[Tuple[float]] = None,
+         force_upcast: float = True,
+         use_quant_conv: bool = False,
+         use_post_quant_conv: bool = False,
+         slice_mag_vae=False,
+         slice_compression_vae=False,
+         cache_compression_vae=False,
+         cache_mag_vae=True,
+         use_tiling=False,
+         mini_batch_encoder=4,
+         mini_batch_decoder=1,
+     ):
+         super().__init__()
+
+         self.encoder = CogVideoXEncoder3D(
+             in_channels=in_channels,
+             out_channels=latent_channels,
+             down_block_types=down_block_types,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             act_fn=act_fn,
+             norm_eps=norm_eps,
+             norm_num_groups=norm_num_groups,
+             temporal_compression_ratio=temporal_compression_ratio,
+         )
+         self.decoder = CogVideoXDecoder3D(
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             up_block_types=up_block_types,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             act_fn=act_fn,
+             norm_eps=norm_eps,
+             norm_num_groups=norm_num_groups,
+             temporal_compression_ratio=temporal_compression_ratio,
+         )
+         self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
+         self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
|
709 |
+
|
710 |
+
self.use_slicing = False
|
711 |
+
self.use_tiling = use_tiling
|
712 |
+
|
713 |
+
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
714 |
+
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
715 |
+
# If you decode X latent frames together, the number of output frames is:
|
716 |
+
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
717 |
+
#
|
718 |
+
# Example with num_latent_frames_batch_size = 2:
|
719 |
+
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
720 |
+
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
721 |
+
# => 6 * 8 = 48 frames
|
722 |
+
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
723 |
+
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
724 |
+
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
725 |
+
# => 1 * 9 + 5 * 8 = 49 frames
|
726 |
+
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
727 |
+
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
728 |
+
# number of temporal frames.
|
729 |
+
self.num_latent_frames_batch_size = 2
|
730 |
+
|
731 |
+
# We make the minimum height and width of sample for tiling half that of the generally supported
|
732 |
+
self.tile_sample_min_height = sample_height // 2
|
733 |
+
self.tile_sample_min_width = sample_width // 2
|
734 |
+
self.tile_latent_min_height = int(
|
735 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
736 |
+
)
|
737 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
738 |
+
|
739 |
+
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
740 |
+
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
741 |
+
# and so the tiling implementation has only been tested on those specific resolutions.
|
742 |
+
self.tile_overlap_factor_height = 1 / 6
|
743 |
+
self.tile_overlap_factor_width = 1 / 5
|
744 |
+
|
745 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
746 |
+
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
747 |
+
module.gradient_checkpointing = value
|
748 |
+
|
749 |
+
def _clear_fake_context_parallel_cache(self):
|
750 |
+
for name, module in self.named_modules():
|
751 |
+
if isinstance(module, CogVideoXCausalConv3d):
|
752 |
+
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
753 |
+
module._clear_fake_context_parallel_cache()
|
754 |
+
|
755 |
+
def enable_tiling(
|
756 |
+
self,
|
757 |
+
tile_sample_min_height: Optional[int] = None,
|
758 |
+
tile_sample_min_width: Optional[int] = None,
|
759 |
+
tile_overlap_factor_height: Optional[float] = None,
|
760 |
+
tile_overlap_factor_width: Optional[float] = None,
|
761 |
+
) -> None:
|
762 |
+
r"""
|
763 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
764 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
765 |
+
processing larger images.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
tile_sample_min_height (`int`, *optional*):
|
769 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
770 |
+
tile_sample_min_width (`int`, *optional*):
|
771 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
772 |
+
tile_overlap_factor_height (`int`, *optional*):
|
773 |
+
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
774 |
+
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
775 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
776 |
+
tile_overlap_factor_width (`int`, *optional*):
|
777 |
+
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
778 |
+
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
779 |
+
value might cause more tiles to be processed leading to slow down of the decoding process.
|
780 |
+
"""
|
781 |
+
self.use_tiling = True
|
782 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
783 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
784 |
+
self.tile_latent_min_height = int(
|
785 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
786 |
+
)
|
787 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
788 |
+
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
789 |
+
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
790 |
+
|
791 |
+
def disable_tiling(self) -> None:
|
792 |
+
r"""
|
793 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
794 |
+
decoding in one step.
|
795 |
+
"""
|
796 |
+
self.use_tiling = False
|
797 |
+
|
798 |
+
def enable_slicing(self) -> None:
|
799 |
+
r"""
|
800 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
801 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
802 |
+
"""
|
803 |
+
self.use_slicing = True
|
804 |
+
|
805 |
+
def disable_slicing(self) -> None:
|
806 |
+
r"""
|
807 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
808 |
+
decoding in one step.
|
809 |
+
"""
|
810 |
+
self.use_slicing = False
|
811 |
+
|
812 |
+
@apply_forward_hook
|
813 |
+
def encode(
|
814 |
+
self, x: torch.Tensor, return_dict: bool = True
|
815 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
816 |
+
"""
|
817 |
+
Encode a batch of images into latents.
|
818 |
+
|
819 |
+
Args:
|
820 |
+
x (`torch.Tensor`): Input batch of images.
|
821 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
822 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
823 |
+
|
824 |
+
Returns:
|
825 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
826 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
827 |
+
"""
|
828 |
+
batch_size, num_channels, num_frames, height, width = x.shape
|
829 |
+
if num_frames == 1:
|
830 |
+
h = self.encoder(x)
|
831 |
+
if self.quant_conv is not None:
|
832 |
+
h = self.quant_conv(h)
|
833 |
+
posterior = DiagonalGaussianDistribution(h)
|
834 |
+
else:
|
835 |
+
frame_batch_size = 4
|
836 |
+
h = []
|
837 |
+
for i in range(num_frames // frame_batch_size):
|
838 |
+
remaining_frames = num_frames % frame_batch_size
|
839 |
+
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
840 |
+
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
841 |
+
z_intermediate = x[:, :, start_frame:end_frame]
|
842 |
+
z_intermediate = self.encoder(z_intermediate)
|
843 |
+
if self.quant_conv is not None:
|
844 |
+
z_intermediate = self.quant_conv(z_intermediate)
|
845 |
+
h.append(z_intermediate)
|
846 |
+
self._clear_fake_context_parallel_cache()
|
847 |
+
h = torch.cat(h, dim=2)
|
848 |
+
posterior = DiagonalGaussianDistribution(h)
|
849 |
+
self._clear_fake_context_parallel_cache()
|
850 |
+
if not return_dict:
|
851 |
+
return (posterior,)
|
852 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
853 |
+
|
854 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
855 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
856 |
+
|
857 |
+
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
858 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
859 |
+
|
860 |
+
if num_frames == 1:
|
861 |
+
dec = []
|
862 |
+
z_intermediate = z
|
863 |
+
if self.post_quant_conv is not None:
|
864 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
865 |
+
z_intermediate = self.decoder(z_intermediate)
|
866 |
+
dec.append(z_intermediate)
|
867 |
+
else:
|
868 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
869 |
+
dec = []
|
870 |
+
for i in range(num_frames // frame_batch_size):
|
871 |
+
remaining_frames = num_frames % frame_batch_size
|
872 |
+
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
873 |
+
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
874 |
+
z_intermediate = z[:, :, start_frame:end_frame]
|
875 |
+
if self.post_quant_conv is not None:
|
876 |
+
z_intermediate = self.post_quant_conv(z_intermediate)
|
877 |
+
z_intermediate = self.decoder(z_intermediate)
|
878 |
+
dec.append(z_intermediate)
|
879 |
+
|
880 |
+
self._clear_fake_context_parallel_cache()
|
881 |
+
dec = torch.cat(dec, dim=2)
|
882 |
+
|
883 |
+
if not return_dict:
|
884 |
+
return (dec,)
|
885 |
+
|
886 |
+
return DecoderOutput(sample=dec)
|
887 |
+
|
888 |
+
@apply_forward_hook
|
889 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
890 |
+
"""
|
891 |
+
Decode a batch of images.
|
892 |
+
|
893 |
+
Args:
|
894 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
895 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
896 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
897 |
+
|
898 |
+
Returns:
|
899 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
900 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
901 |
+
returned.
|
902 |
+
"""
|
903 |
+
if self.use_slicing and z.shape[0] > 1:
|
904 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
905 |
+
decoded = torch.cat(decoded_slices)
|
906 |
+
else:
|
907 |
+
decoded = self._decode(z).sample
|
908 |
+
|
909 |
+
if not return_dict:
|
910 |
+
return (decoded,)
|
911 |
+
return DecoderOutput(sample=decoded)
|
912 |
+
|
913 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
914 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
915 |
+
for y in range(blend_extent):
|
916 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
917 |
+
y / blend_extent
|
918 |
+
)
|
919 |
+
return b
|
920 |
+
|
921 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
922 |
+
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
923 |
+
for x in range(blend_extent):
|
924 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
925 |
+
x / blend_extent
|
926 |
+
)
|
927 |
+
return b
|
928 |
+
|
929 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
930 |
+
r"""
|
931 |
+
Decode a batch of images using a tiled decoder.
|
932 |
+
|
933 |
+
Args:
|
934 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
935 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
936 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
937 |
+
|
938 |
+
Returns:
|
939 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
940 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
941 |
+
returned.
|
942 |
+
"""
|
943 |
+
# Rough memory assessment:
|
944 |
+
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
945 |
+
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
946 |
+
# - Assume fp16 (2 bytes per value).
|
947 |
+
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
948 |
+
#
|
949 |
+
# Memory assessment when using tiling:
|
950 |
+
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
951 |
+
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
952 |
+
|
953 |
+
batch_size, num_channels, num_frames, height, width = z.shape
|
954 |
+
|
955 |
+
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
956 |
+
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
957 |
+
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
958 |
+
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
959 |
+
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
960 |
+
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
961 |
+
frame_batch_size = self.num_latent_frames_batch_size
|
962 |
+
|
963 |
+
# Split z into overlapping tiles and decode them separately.
|
964 |
+
# The tiles have an overlap to avoid seams between tiles.
|
965 |
+
rows = []
|
966 |
+
for i in range(0, height, overlap_height):
|
967 |
+
row = []
|
968 |
+
for j in range(0, width, overlap_width):
|
969 |
+
time = []
|
970 |
+
for k in range(num_frames // frame_batch_size):
|
971 |
+
remaining_frames = num_frames % frame_batch_size
|
972 |
+
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
973 |
+
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
974 |
+
tile = z[
|
975 |
+
:,
|
976 |
+
:,
|
977 |
+
start_frame:end_frame,
|
978 |
+
i : i + self.tile_latent_min_height,
|
979 |
+
j : j + self.tile_latent_min_width,
|
980 |
+
]
|
981 |
+
if self.post_quant_conv is not None:
|
982 |
+
tile = self.post_quant_conv(tile)
|
983 |
+
tile = self.decoder(tile)
|
984 |
+
time.append(tile)
|
985 |
+
self._clear_fake_context_parallel_cache()
|
986 |
+
row.append(torch.cat(time, dim=2))
|
987 |
+
rows.append(row)
|
988 |
+
|
989 |
+
result_rows = []
|
990 |
+
for i, row in enumerate(rows):
|
991 |
+
result_row = []
|
992 |
+
for j, tile in enumerate(row):
|
993 |
+
# blend the above tile and the left tile
|
994 |
+
# to the current tile and add the current tile to the result row
|
995 |
+
if i > 0:
|
996 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
997 |
+
if j > 0:
|
998 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
999 |
+
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
1000 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
1001 |
+
|
1002 |
+
dec = torch.cat(result_rows, dim=3)
|
1003 |
+
|
1004 |
+
if not return_dict:
|
1005 |
+
return (dec,)
|
1006 |
+
|
1007 |
+
return DecoderOutput(sample=dec)
|
1008 |
+
|
1009 |
+
def forward(
|
1010 |
+
self,
|
1011 |
+
sample: torch.Tensor,
|
1012 |
+
sample_posterior: bool = False,
|
1013 |
+
return_dict: bool = True,
|
1014 |
+
generator: Optional[torch.Generator] = None,
|
1015 |
+
) -> Union[torch.Tensor, torch.Tensor]:
|
1016 |
+
x = sample
|
1017 |
+
posterior = self.encode(x).latent_dist
|
1018 |
+
if sample_posterior:
|
1019 |
+
z = posterior.sample(generator=generator)
|
1020 |
+
else:
|
1021 |
+
z = posterior.mode()
|
1022 |
+
dec = self.decode(z)
|
1023 |
+
if not return_dict:
|
1024 |
+
return (dec,)
|
1025 |
+
return dec
|
easyanimate/models/embeddings.py
ADDED
@@ -0,0 +1,107 @@
import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding,
                                         TimestepEmbedding, Timesteps)
from einops import rearrange
from torch import nn


class HunyuanDiTAttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)

        query = self.q_proj(x[:, :1])
        key = self.k_proj(x)
        value = self.v_proj(x)
        batch_size, _, _ = query.size()

        query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2)  # (1, H, N, E/H)
        key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2)  # (L+1, H, N, E/H)
        value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2)  # (L+1, H, N, E/H)

        x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, 1, -1)
        x = x.to(query.dtype)
        x = self.c_proj(x)

        return x.squeeze(1)


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )
        # Here we use a default learned embedder layer for future extension.
        self.style_embedder = nn.Embedding(1, embedding_dim)
        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)

        # extra condition1: text
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

        # extra condition2: image meta size embdding
        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)

        # extra condition3: style embedding
        style_embedding = self.style_embedder(style)  # (N, embedding_dim)

        # Concatenate all extra vectors
        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]

        return conditioning


class TimePositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 24
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        b, c, f, h, w = x.size()
        x = rearrange(x, "b c f h w -> (b h w) f c")
        x = x + self.pe[:, :x.size(1)]
        x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
        return self.dropout(x)
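A quick sanity check of the `TimePositionalEncoding` added above: it injects a sinusoidal code per frame index, broadcast over the batch and spatial positions, so only the frame axis is distinguished. The import path below is assumed from the file location; a sketch:

import torch
from easyanimate.models.embeddings import TimePositionalEncoding

pe = TimePositionalEncoding(d_model=8)
x = torch.zeros(2, 8, 6, 4, 4)            # (b, c, f, h, w)
out = pe(x)
# Frame 0 of a zero input is exactly the code for position 0: sin(0)=0 on even channels, cos(0)=1 on odd.
print(out[0, :, 0, 0, 0])                 # tensor([0., 1., 0., 1., 0., 1., 0., 1.])
# Every spatial location inside the same frame receives the same code.
assert torch.allclose(out[0, :, 3, 0, 0], out[0, :, 3, 2, 1])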
easyanimate/models/norm.py
CHANGED
@@ -2,7 +2,8 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
-from diffusers.models.embeddings import
+from diffusers.models.embeddings import (CombinedTimestepLabelEmbeddings,
+                                         TimestepEmbedding, Timesteps)
 from torch import nn
 
 
@@ -12,7 +13,6 @@ def zero_module(module):
         p.detach().zero_()
     return module
 
-
 class FP32LayerNorm(nn.LayerNorm):
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         origin_dtype = inputs.dtype
@@ -95,3 +95,56 @@ class AdaLayerNormSingle(nn.Module):
         # No modulation happening here.
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+class AdaLayerNormShift(nn.Module):
+    r"""
+    Norm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, embedding_dim)
+        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
+        x = self.norm(x) + shift.unsqueeze(dim=1)
+        return x
+
+class EasyAnimateLayerNormZero(nn.Module):
+    # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py
+    # Add fp32 layer norm
+    def __init__(
+        self,
+        conditioning_dim: int,
+        embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "fp32_layer_norm",
+    ) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
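A hedged usage sketch for the `EasyAnimateLayerNormZero` added above: a single conditioning vector is projected into shift/scale/gate triples for both the video-token stream and the text-token stream. The shapes below are illustrative only, and the import path is assumed from the file location:

import torch
from easyanimate.models.norm import EasyAnimateLayerNormZero

norm = EasyAnimateLayerNormZero(conditioning_dim=64, embedding_dim=32)
hidden_states = torch.randn(2, 128, 32)          # video tokens
encoder_hidden_states = torch.randn(2, 16, 32)   # text tokens
temb = torch.randn(2, 64)                        # time/conditioning embedding

h, e, gate, enc_gate = norm(hidden_states, encoder_hidden_states, temb)
print(h.shape, e.shape, gate.shape, enc_gate.shape)
# torch.Size([2, 128, 32]) torch.Size([2, 16, 32]) torch.Size([2, 1, 32]) torch.Size([2, 1, 32])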
easyanimate/models/patch.py
CHANGED
@@ -153,15 +153,6 @@ class TemporalUpsampler3D(Upsampler):
         x = torch.cat([first_frame, x], dim=2)
         return x
 
-def cast_tuple(t, length = 1):
-    return t if isinstance(t, tuple) else ((t,) * length)
-
-def divisible_by(num, den):
-    return (num % den) == 0
-
-def is_odd(n):
-    return not divisible_by(n, 2)
-
 class CausalConv3d(nn.Conv3d):
     def __init__(
         self,
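`CausalConv3d` keeps the temporal axis causal by padding only before the first frame, so frame t never attends to frames after t. The class body is not shown in this hunk, so the sketch below is only a generic illustration of that padding scheme, not the repository implementation:

import torch
import torch.nn.functional as F

def causal_conv3d(x, weight, bias=None):
    # weight: (out_c, in_c, kt, kh, kw); pad time only on the left, space symmetrically.
    kt, kh, kw = weight.shape[-3:]
    x = F.pad(x, (kw // 2, kw // 2, kh // 2, kh // 2, kt - 1, 0))
    return F.conv3d(x, weight, bias)

x = torch.randn(1, 3, 8, 16, 16)
w = torch.randn(4, 3, 3, 3, 3)
y = causal_conv3d(x, w)
print(y.shape)  # torch.Size([1, 4, 8, 16, 16]) -- same frame count, no temporal look-ahead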
easyanimate/models/processor.py
ADDED
@@ -0,0 +1,312 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from diffusers.models.attention import Attention
|
6 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
|
10 |
+
class HunyuanAttnProcessor2_0:
|
11 |
+
r"""
|
12 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
13 |
+
used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
18 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
19 |
+
|
20 |
+
def __call__(
|
21 |
+
self,
|
22 |
+
attn: Attention,
|
23 |
+
hidden_states: torch.Tensor,
|
24 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
25 |
+
attention_mask: Optional[torch.Tensor] = None,
|
26 |
+
temb: Optional[torch.Tensor] = None,
|
27 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
28 |
+
) -> torch.Tensor:
|
29 |
+
residual = hidden_states
|
30 |
+
if attn.spatial_norm is not None:
|
31 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
32 |
+
|
33 |
+
input_ndim = hidden_states.ndim
|
34 |
+
|
35 |
+
if input_ndim == 4:
|
36 |
+
batch_size, channel, height, width = hidden_states.shape
|
37 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
38 |
+
|
39 |
+
batch_size, sequence_length, _ = (
|
40 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
41 |
+
)
|
42 |
+
|
43 |
+
if attention_mask is not None:
|
44 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
45 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
46 |
+
# (batch, heads, source_length, target_length)
|
47 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
48 |
+
|
49 |
+
if attn.group_norm is not None:
|
50 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
51 |
+
|
52 |
+
query = attn.to_q(hidden_states)
|
53 |
+
|
54 |
+
if encoder_hidden_states is None:
|
55 |
+
encoder_hidden_states = hidden_states
|
56 |
+
elif attn.norm_cross:
|
57 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
58 |
+
|
59 |
+
key = attn.to_k(encoder_hidden_states)
|
60 |
+
value = attn.to_v(encoder_hidden_states)
|
61 |
+
|
62 |
+
inner_dim = key.shape[-1]
|
63 |
+
head_dim = inner_dim // attn.heads
|
64 |
+
|
65 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
66 |
+
|
67 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
68 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
69 |
+
|
70 |
+
if attn.norm_q is not None:
|
71 |
+
query = attn.norm_q(query)
|
72 |
+
if attn.norm_k is not None:
|
73 |
+
key = attn.norm_k(key)
|
74 |
+
|
75 |
+
# Apply RoPE if needed
|
76 |
+
if image_rotary_emb is not None:
|
77 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
78 |
+
if not attn.is_cross_attention:
|
79 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
80 |
+
|
81 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
82 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
83 |
+
hidden_states = F.scaled_dot_product_attention(
|
84 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
85 |
+
)
|
86 |
+
|
87 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
88 |
+
hidden_states = hidden_states.to(query.dtype)
|
89 |
+
|
90 |
+
# linear proj
|
91 |
+
hidden_states = attn.to_out[0](hidden_states)
|
92 |
+
# dropout
|
93 |
+
hidden_states = attn.to_out[1](hidden_states)
|
94 |
+
|
95 |
+
if input_ndim == 4:
|
96 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
97 |
+
|
98 |
+
if attn.residual_connection:
|
99 |
+
hidden_states = hidden_states + residual
|
100 |
+
|
101 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
102 |
+
|
103 |
+
return hidden_states
|
104 |
+
|
105 |
+
class LazyKVCompressionProcessor2_0:
|
106 |
+
r"""
|
107 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
108 |
+
used in the KVCompression model. It applies a s normalization layer and rotary embedding on query and key vector.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(self):
|
112 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
113 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
114 |
+
|
115 |
+
def __call__(
|
116 |
+
self,
|
117 |
+
attn: Attention,
|
118 |
+
hidden_states: torch.Tensor,
|
119 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
120 |
+
attention_mask: Optional[torch.Tensor] = None,
|
121 |
+
temb: Optional[torch.Tensor] = None,
|
122 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
123 |
+
) -> torch.Tensor:
|
124 |
+
residual = hidden_states
|
125 |
+
if attn.spatial_norm is not None:
|
126 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
127 |
+
|
128 |
+
input_ndim = hidden_states.ndim
|
129 |
+
|
130 |
+
batch_size, channel, num_frames, height, width = hidden_states.shape
|
131 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=num_frames, h=height, w=width)
|
132 |
+
|
133 |
+
batch_size, sequence_length, _ = (
|
134 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
135 |
+
)
|
136 |
+
|
137 |
+
if attention_mask is not None:
|
138 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
139 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
140 |
+
# (batch, heads, source_length, target_length)
|
141 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
142 |
+
|
143 |
+
if attn.group_norm is not None:
|
144 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
145 |
+
|
146 |
+
query = attn.to_q(hidden_states)
|
147 |
+
|
148 |
+
if encoder_hidden_states is None:
|
149 |
+
encoder_hidden_states = hidden_states
|
150 |
+
elif attn.norm_cross:
|
151 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
152 |
+
|
153 |
+
key = attn.to_k(encoder_hidden_states)
|
154 |
+
value = attn.to_v(encoder_hidden_states)
|
155 |
+
|
156 |
+
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
157 |
+
key = attn.k_compression(key)
|
158 |
+
key_shape = key.size()
|
159 |
+
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
|
160 |
+
|
161 |
+
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
162 |
+
value = attn.v_compression(value)
|
163 |
+
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
|
164 |
+
|
165 |
+
inner_dim = key.shape[-1]
|
166 |
+
head_dim = inner_dim // attn.heads
|
167 |
+
|
168 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
169 |
+
|
170 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
171 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
172 |
+
|
173 |
+
if attn.norm_q is not None:
|
174 |
+
query = attn.norm_q(query)
|
175 |
+
if attn.norm_k is not None:
|
176 |
+
key = attn.norm_k(key)
|
177 |
+
|
178 |
+
# Apply RoPE if needed
|
179 |
+
if image_rotary_emb is not None:
|
180 |
+
compression_image_rotary_emb = (
|
181 |
+
rearrange(image_rotary_emb[0], "(f h w) c -> f c h w", f=num_frames, h=height, w=width),
|
182 |
+
rearrange(image_rotary_emb[1], "(f h w) c -> f c h w", f=num_frames, h=height, w=width),
|
183 |
+
)
|
184 |
+
compression_image_rotary_emb = (
|
185 |
+
F.interpolate(compression_image_rotary_emb[0], size=key_shape[-2:], mode='bilinear'),
|
186 |
+
F.interpolate(compression_image_rotary_emb[1], size=key_shape[-2:], mode='bilinear')
|
187 |
+
)
|
188 |
+
compression_image_rotary_emb = (
|
189 |
+
rearrange(compression_image_rotary_emb[0], "f c h w -> (f h w) c"),
|
190 |
+
rearrange(compression_image_rotary_emb[1], "f c h w -> (f h w) c"),
|
191 |
+
)
|
192 |
+
|
193 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
194 |
+
if not attn.is_cross_attention:
|
195 |
+
key = apply_rotary_emb(key, compression_image_rotary_emb)
|
196 |
+
|
197 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
198 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
199 |
+
hidden_states = F.scaled_dot_product_attention(
|
200 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
201 |
+
)
|
202 |
+
|
203 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
204 |
+
hidden_states = hidden_states.to(query.dtype)
|
205 |
+
|
206 |
+
# linear proj
|
207 |
+
hidden_states = attn.to_out[0](hidden_states)
|
208 |
+
# dropout
|
209 |
+
hidden_states = attn.to_out[1](hidden_states)
|
210 |
+
|
211 |
+
if attn.residual_connection:
|
212 |
+
hidden_states = hidden_states + residual
|
213 |
+
|
214 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
215 |
+
|
216 |
+
return hidden_states
|
217 |
+
|
218 |
+
class EasyAnimateAttnProcessor2_0:
|
219 |
+
def __init__(self):
|
220 |
+
pass
|
221 |
+
|
222 |
+
def __call__(
|
223 |
+
self,
|
224 |
+
attn: Attention,
|
225 |
+
hidden_states: torch.Tensor,
|
226 |
+
encoder_hidden_states: torch.Tensor,
|
227 |
+
attention_mask: Optional[torch.Tensor] = None,
|
228 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
229 |
+
attn2: Attention = None,
|
230 |
+
) -> torch.Tensor:
|
231 |
+
text_seq_length = encoder_hidden_states.size(1)
|
232 |
+
|
233 |
+
batch_size, sequence_length, _ = (
|
234 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
235 |
+
)
|
236 |
+
|
237 |
+
if attention_mask is not None:
|
238 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
239 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
240 |
+
|
241 |
+
if attn2 is None:
|
242 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
243 |
+
|
244 |
+
query = attn.to_q(hidden_states)
|
245 |
+
key = attn.to_k(hidden_states)
|
246 |
+
value = attn.to_v(hidden_states)
|
247 |
+
|
248 |
+
inner_dim = key.shape[-1]
|
249 |
+
head_dim = inner_dim // attn.heads
|
250 |
+
|
251 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
252 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
253 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
254 |
+
|
255 |
+
if attn.norm_q is not None:
|
256 |
+
query = attn.norm_q(query)
|
257 |
+
if attn.norm_k is not None:
|
258 |
+
key = attn.norm_k(key)
|
259 |
+
|
260 |
+
if attn2 is not None:
|
261 |
+
query_txt = attn2.to_q(encoder_hidden_states)
|
262 |
+
key_txt = attn2.to_k(encoder_hidden_states)
|
263 |
+
value_txt = attn2.to_v(encoder_hidden_states)
|
264 |
+
|
265 |
+
inner_dim = key_txt.shape[-1]
|
266 |
+
head_dim = inner_dim // attn.heads
|
267 |
+
|
268 |
+
query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
269 |
+
key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
270 |
+
value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
271 |
+
|
272 |
+
if attn2.norm_q is not None:
|
273 |
+
query_txt = attn2.norm_q(query_txt)
|
274 |
+
if attn2.norm_k is not None:
|
275 |
+
key_txt = attn2.norm_k(key_txt)
|
276 |
+
|
277 |
+
query = torch.cat([query_txt, query], dim=2)
|
278 |
+
key = torch.cat([key_txt, key], dim=2)
|
279 |
+
value = torch.cat([value_txt, value], dim=2)
|
280 |
+
|
281 |
+
# Apply RoPE if needed
|
282 |
+
if image_rotary_emb is not None:
|
283 |
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
284 |
+
if not attn.is_cross_attention:
|
285 |
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
286 |
+
|
287 |
+
hidden_states = F.scaled_dot_product_attention(
|
288 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
289 |
+
)
|
290 |
+
|
291 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
292 |
+
|
293 |
+
if attn2 is None:
|
294 |
+
# linear proj
|
295 |
+
hidden_states = attn.to_out[0](hidden_states)
|
296 |
+
# dropout
|
297 |
+
hidden_states = attn.to_out[1](hidden_states)
|
298 |
+
|
299 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
300 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
304 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
305 |
+
)
|
306 |
+
# linear proj
|
307 |
+
hidden_states = attn.to_out[0](hidden_states)
|
308 |
+
encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
|
309 |
+
# dropout
|
310 |
+
hidden_states = attn.to_out[1](hidden_states)
|
311 |
+
encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
|
312 |
+
return hidden_states, encoder_hidden_states
|
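`EasyAnimateAttnProcessor2_0` above concatenates text and video tokens into one sequence, runs a single joint attention pass (applying RoPE to the video positions only), and splits the two streams apart again afterwards. A toy-shape sketch of that token layout (standalone tensors, not the repository classes):

import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 4, 16
text_len, video_len = 8, 64

text_q = torch.randn(batch, heads, text_len, head_dim)
video_q = torch.randn(batch, heads, video_len, head_dim)
q = torch.cat([text_q, video_q], dim=2)        # RoPE would touch q[:, :, text_len:] only
k, v = q.clone(), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v)  # joint attention over text + video tokens
encoder_out, hidden_out = out.split([text_len, video_len], dim=2)
print(encoder_out.shape, hidden_out.shape)
# torch.Size([2, 4, 8, 16]) torch.Size([2, 4, 64, 16])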
easyanimate/models/resampler.py
ADDED
@@ -0,0 +1,146 @@
1 |
+
# Copyright (c) Alibaba Cloud.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from torch.nn.init import normal_
|
13 |
+
|
14 |
+
|
15 |
+
def get_abs_pos(abs_pos, tgt_size):
|
16 |
+
# abs_pos: L, C
|
17 |
+
# tgt_size: M
|
18 |
+
# return: M, C
|
19 |
+
src_size = int(math.sqrt(abs_pos.size(0)))
|
20 |
+
tgt_size = int(math.sqrt(tgt_size))
|
21 |
+
dtype = abs_pos.dtype
|
22 |
+
|
23 |
+
if src_size != tgt_size:
|
24 |
+
return F.interpolate(
|
25 |
+
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
26 |
+
size=(tgt_size, tgt_size),
|
27 |
+
mode="bicubic",
|
28 |
+
align_corners=False,
|
29 |
+
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
30 |
+
else:
|
31 |
+
return abs_pos
|
32 |
+
|
33 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
34 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
35 |
+
"""
|
36 |
+
grid_size: int of the grid height and width
|
37 |
+
return:
|
38 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
39 |
+
"""
|
40 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
41 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
42 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
43 |
+
grid = np.stack(grid, axis=0)
|
44 |
+
|
45 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
46 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
47 |
+
if cls_token:
|
48 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
49 |
+
return pos_embed
|
50 |
+
|
51 |
+
|
52 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
53 |
+
assert embed_dim % 2 == 0
|
54 |
+
|
55 |
+
# use half of dimensions to encode grid_h
|
56 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
57 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
58 |
+
|
59 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
60 |
+
return emb
|
61 |
+
|
62 |
+
|
63 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
64 |
+
"""
|
65 |
+
embed_dim: output dimension for each position
|
66 |
+
pos: a list of positions to be encoded: size (M,)
|
67 |
+
out: (M, D)
|
68 |
+
"""
|
69 |
+
assert embed_dim % 2 == 0
|
70 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
71 |
+
omega /= embed_dim / 2.
|
72 |
+
omega = 1. / 10000**omega # (D/2,)
|
73 |
+
|
74 |
+
pos = pos.reshape(-1) # (M,)
|
75 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
76 |
+
|
77 |
+
emb_sin = np.sin(out) # (M, D/2)
|
78 |
+
emb_cos = np.cos(out) # (M, D/2)
|
79 |
+
|
80 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
81 |
+
return emb
|
82 |
+
|
83 |
+
class Resampler(nn.Module):
|
84 |
+
"""
|
85 |
+
A 2D perceiver-resampler network with one cross attention layers by
|
86 |
+
(grid_size**2) learnable queries and 2d sincos pos_emb
|
87 |
+
Outputs:
|
88 |
+
A tensor with the shape of (grid_size**2, embed_dim)
|
89 |
+
"""
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
grid_size,
|
93 |
+
embed_dim,
|
94 |
+
num_heads,
|
95 |
+
kv_dim=None,
|
96 |
+
norm_layer=nn.LayerNorm
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
self.num_queries = grid_size ** 2
|
100 |
+
self.embed_dim = embed_dim
|
101 |
+
self.num_heads = num_heads
|
102 |
+
|
103 |
+
self.pos_embed = nn.Parameter(
|
104 |
+
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
|
105 |
+
).requires_grad_(False)
|
106 |
+
|
107 |
+
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
108 |
+
normal_(self.query, std=.02)
|
109 |
+
|
110 |
+
if kv_dim is not None and kv_dim != embed_dim:
|
111 |
+
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
|
112 |
+
else:
|
113 |
+
self.kv_proj = nn.Identity()
|
114 |
+
|
115 |
+
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
116 |
+
self.ln_q = norm_layer(embed_dim)
|
117 |
+
self.ln_kv = norm_layer(embed_dim)
|
118 |
+
|
119 |
+
self.apply(self._init_weights)
|
120 |
+
|
121 |
+
def _init_weights(self, m):
|
122 |
+
if isinstance(m, nn.Linear):
|
123 |
+
normal_(m.weight, std=.02)
|
124 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
125 |
+
nn.init.constant_(m.bias, 0)
|
126 |
+
elif isinstance(m, nn.LayerNorm):
|
127 |
+
nn.init.constant_(m.bias, 0)
|
128 |
+
nn.init.constant_(m.weight, 1.0)
|
129 |
+
|
130 |
+
def forward(self, x, key_padding_mask=None):
|
131 |
+
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
|
132 |
+
|
133 |
+
x = self.kv_proj(x)
|
134 |
+
x = self.ln_kv(x).permute(1, 0, 2)
|
135 |
+
|
136 |
+
N = x.shape[1]
|
137 |
+
q = self.ln_q(self.query)
|
138 |
+
out = self.attn(
|
139 |
+
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
|
140 |
+
x + pos_embed.unsqueeze(1),
|
141 |
+
x,
|
142 |
+
key_padding_mask=key_padding_mask)[0]
|
143 |
+
return out.permute(1, 0, 2)
|
144 |
+
|
145 |
+
def _repeat(self, query, N: int):
|
146 |
+
return query.unsqueeze(1).repeat(1, N, 1)
|
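A hedged usage sketch for the `Resampler` added above: `grid_size ** 2` learned queries cross-attend into a variable-length image-feature sequence and return a fixed-length output, with 2D sin/cos position codes on both sides. The import path is assumed from the file location:

import torch
from easyanimate.models.resampler import Resampler, get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=4)
print(pos.shape)                           # (16, 64) -- one sin/cos code per grid cell

resampler = Resampler(grid_size=4, embed_dim=64, num_heads=4, kv_dim=128)
image_features = torch.randn(2, 100, 128)  # e.g. 10x10 patch features from a vision tower
out = resampler(image_features)
print(out.shape)                           # torch.Size([2, 16, 64]) -- fixed number of query tokens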
easyanimate/models/transformer2d.py
CHANGED
@@ -37,10 +37,6 @@ except:
|
|
37 |
from diffusers.models.embeddings import \
|
38 |
CaptionProjection as PixArtAlphaTextProjection
|
39 |
|
40 |
-
from .attention import (KVCompressionTransformerBlock,
|
41 |
-
SelfAttentionTemporalTransformerBlock,
|
42 |
-
TemporalTransformerBlock)
|
43 |
-
|
44 |
|
45 |
@dataclass
|
46 |
class Transformer2DModelOutput(BaseOutput):
|
@@ -196,58 +192,29 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|
196 |
interpolation_scale=interpolation_scale,
|
197 |
)
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
kvcompression=False if d < 14 else True,
|
223 |
-
)
|
224 |
-
for d in range(num_layers)
|
225 |
-
]
|
226 |
-
)
|
227 |
-
else:
|
228 |
-
# 3. Define transformers blocks
|
229 |
-
self.transformer_blocks = nn.ModuleList(
|
230 |
-
[
|
231 |
-
BasicTransformerBlock(
|
232 |
-
inner_dim,
|
233 |
-
num_attention_heads,
|
234 |
-
attention_head_dim,
|
235 |
-
dropout=dropout,
|
236 |
-
cross_attention_dim=cross_attention_dim,
|
237 |
-
activation_fn=activation_fn,
|
238 |
-
num_embeds_ada_norm=num_embeds_ada_norm,
|
239 |
-
attention_bias=attention_bias,
|
240 |
-
only_cross_attention=only_cross_attention,
|
241 |
-
double_self_attention=double_self_attention,
|
242 |
-
upcast_attention=upcast_attention,
|
243 |
-
norm_type=norm_type,
|
244 |
-
norm_elementwise_affine=norm_elementwise_affine,
|
245 |
-
norm_eps=norm_eps,
|
246 |
-
attention_type=attention_type,
|
247 |
-
)
|
248 |
-
for d in range(num_layers)
|
249 |
-
]
|
250 |
-
)
|
251 |
|
252 |
# 4. Define output layers
|
253 |
self.out_channels = in_channels if out_channels is None else out_channels
|
@@ -413,7 +380,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|
413 |
if self.training and self.gradient_checkpointing:
|
414 |
args = {
|
415 |
"basic": [],
|
416 |
-
"kvcompression": [1, height, width],
|
417 |
}[self.basic_block_type]
|
418 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
419 |
block,
|
@@ -430,7 +396,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|
430 |
else:
|
431 |
kwargs = {
|
432 |
"basic": {},
|
433 |
-
"kvcompression": {"num_frames":1, "height":height, "width":width},
|
434 |
}[self.basic_block_type]
|
435 |
hidden_states = block(
|
436 |
hidden_states,
|
|
|
37 |
from diffusers.models.embeddings import \
|
38 |
CaptionProjection as PixArtAlphaTextProjection
|
39 |
|
|
|
|
|
|
|
|
|
40 |
|
41 |
@dataclass
|
42 |
class Transformer2DModelOutput(BaseOutput):
|
|
|
192 |
interpolation_scale=interpolation_scale,
|
193 |
)
|
194 |
|
195 |
+
# 3. Define transformers blocks
|
196 |
+
self.transformer_blocks = nn.ModuleList(
|
197 |
+
[
|
198 |
+
BasicTransformerBlock(
|
199 |
+
inner_dim,
|
200 |
+
num_attention_heads,
|
201 |
+
attention_head_dim,
|
202 |
+
dropout=dropout,
|
203 |
+
cross_attention_dim=cross_attention_dim,
|
204 |
+
activation_fn=activation_fn,
|
205 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
206 |
+
attention_bias=attention_bias,
|
207 |
+
only_cross_attention=only_cross_attention,
|
208 |
+
double_self_attention=double_self_attention,
|
209 |
+
upcast_attention=upcast_attention,
|
210 |
+
norm_type=norm_type,
|
211 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
212 |
+
norm_eps=norm_eps,
|
213 |
+
attention_type=attention_type,
|
214 |
+
)
|
215 |
+
for d in range(num_layers)
|
216 |
+
]
|
217 |
+
)
|
218 |
|
219 |
# 4. Define output layers
|
220 |
self.out_channels = in_channels if out_channels is None else out_channels
|
|
|
380 |
if self.training and self.gradient_checkpointing:
|
381 |
args = {
|
382 |
"basic": [],
|
|
|
383 |
}[self.basic_block_type]
|
384 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
385 |
block,
|
|
|
396 |
else:
|
397 |
kwargs = {
|
398 |
"basic": {},
|
|
|
399 |
}[self.basic_block_type]
|
400 |
hidden_states = block(
|
401 |
hidden_states,
|
easyanimate/models/transformer3d.py
CHANGED
@@ -11,34 +11,39 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
 import math
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
-import torch.nn.init as init
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
-                                         TimestepEmbedding, Timesteps
-
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.normalization import AdaLayerNormContinuous
 from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
                              logging)
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange
 from torch import nn

-from .attention import (
-
 from .norm import AdaLayerNormSingle
-from .patch import (CasualPatchEmbed3D,
                     TemporalUpsampler3D, UnPatch1D)

 try:
     from diffusers.models.embeddings import PixArtAlphaTextProjection
@@ -46,12 +51,6 @@ except:
     from diffusers.models.embeddings import \
         CaptionProjection as PixArtAlphaTextProjection

-def zero_module(module):
-    # Zero out the parameters of a module and return it.
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-

 class CLIPProjection(nn.Module):
     """
@@ -72,28 +71,6 @@ class CLIPProjection(nn.Module):
         hidden_states = self.linear_2(hidden_states)
         return hidden_states

-class TimePositionalEncoding(nn.Module):
-    def __init__(
-        self,
-        d_model,
-        dropout = 0.,
-        max_len = 24
-    ):
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout)
-        position = torch.arange(max_len).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
-        pe = torch.zeros(1, max_len, d_model)
-        pe[0, :, 0::2] = torch.sin(position * div_term)
-        pe[0, :, 1::2] = torch.cos(position * div_term)
-        self.register_buffer('pe', pe)
-
-    def forward(self, x):
-        b, c, f, h, w = x.size()
-        x = rearrange(x, "b c f h w -> (b h w) f c")
-        x = x + self.pe[:, :x.size(1)]
-        x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
-        return self.dropout(x)

 @dataclass
 class Transformer3DModelOutput(BaseOutput):
@@ -189,6 +166,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):

         qk_norm = False,
         after_norm = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -202,9 +183,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         self.casual_3d = casual_3d
         self.casual_3d_upsampler_index = casual_3d_upsampler_index

-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-
         assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"

         self.height = sample_size
@@ -310,34 +288,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                     for d in range(num_layers)
                 ]
             )
-        elif self.basic_block_type == "kvcompression_motionmodule":
-            self.transformer_blocks = nn.ModuleList(
-                [
-                    TemporalTransformerBlock(
-                        inner_dim,
-                        num_attention_heads,
-                        attention_head_dim,
-                        dropout=dropout,
-                        cross_attention_dim=cross_attention_dim,
-                        activation_fn=activation_fn,
-                        num_embeds_ada_norm=num_embeds_ada_norm,
-                        attention_bias=attention_bias,
-                        only_cross_attention=only_cross_attention,
-                        double_self_attention=double_self_attention,
-                        upcast_attention=upcast_attention,
-                        norm_type=norm_type,
-                        norm_elementwise_affine=norm_elementwise_affine,
-                        norm_eps=norm_eps,
-                        attention_type=attention_type,
-                        kvcompression=False if d < 14 else True,
-                        motion_module_type=motion_module_type,
-                        motion_module_kwargs=motion_module_kwargs,
-                        qk_norm=qk_norm,
-                        after_norm=after_norm,
-                    )
-                    for d in range(num_layers)
-                ]
-            )
         elif self.basic_block_type == "selfattentiontemporal":
             self.transformer_blocks = nn.ModuleList(
                 [
@@ -448,6 +398,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         self,
         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
@@ -524,6 +475,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):

        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
        # 1. Input
        if self.casual_3d:
            video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
@@ -596,7 +549,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                    "motionmodule": [video_length, height, width],
                    "global_motionmodule": [video_length, height, width],
                    "selfattentiontemporal": [],
-                    "kvcompression_motionmodule": [video_length, height, width],
                }[self.basic_block_type]
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
@@ -616,7 +568,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                    "motionmodule": {"num_frames":video_length, "height":height, "width":width},
                    "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                    "selfattentiontemporal": {},
-                    "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                }[self.basic_block_type]
                hidden_states = block(
                    hidden_states,
@@ -741,4 +692,745 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
        params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
        print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")

        return model
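The hunks above drop the local `zero_module` and `TimePositionalEncoding` helpers from `transformer3d.py`; in the updated file they are imported from `.attention` and `.embeddings` instead (see the import changes). For reference, a condensed, runnable restatement of the removed sinusoidal frame-position encoding; the class body comes from the deleted lines, only formatting and the usage line are new.

import math
import torch
import torch.nn as nn
from einops import rearrange


class TimePositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Fold the spatial grid into the batch so the table is added per frame index.
        b, c, f, h, w = x.size()
        x = rearrange(x, "b c f h w -> (b h w) f c")
        x = x + self.pe[:, : x.size(1)]
        x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
        return self.dropout(x)


print(TimePositionalEncoding(d_model=64)(torch.randn(1, 64, 8, 4, 4)).shape)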
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import glob
 import json
 import math
 import os
 from dataclasses import dataclass
+from typing import Any, Dict, Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
+                                         TimestepEmbedding, Timesteps,
+                                         get_2d_sincos_pos_embed)
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous
 from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
                              logging)
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange
 from torch import nn

+from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
+                        SelfAttentionTemporalTransformerBlock,
+                        TemporalTransformerBlock, zero_module)
+from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, TimePositionalEncoding
 from .norm import AdaLayerNormSingle
+from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
+                    TemporalUpsampler3D, UnPatch1D)
+from .resampler import Resampler

 try:
     from diffusers.models.embeddings import PixArtAlphaTextProjection

     from diffusers.models.embeddings import \
         CaptionProjection as PixArtAlphaTextProjection


 class CLIPProjection(nn.Module):
     """

         hidden_states = self.linear_2(hidden_states)
         return hidden_states


 @dataclass
 class Transformer3DModelOutput(BaseOutput):

         qk_norm = False,
         after_norm = False,
+        resize_inpaint_mask_directly: bool = False,
+        enable_clip_in_inpaint: bool = True,
+        enable_text_attention_mask: bool = True,
+        add_noise_in_inpaint_model: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection

         self.casual_3d = casual_3d
         self.casual_3d_upsampler_index = casual_3d_upsampler_index

         assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"

         self.height = sample_size

                 for d in range(num_layers)
             ]
         )
         elif self.basic_block_type == "selfattentiontemporal":
             self.transformer_blocks = nn.ModuleList(
                 [

         self,
         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
+        control_latents: torch.Tensor = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,

         if inpaint_latents is not None:
             hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+        if control_latents is not None:
+            hidden_states = torch.concat([hidden_states, control_latents], 1)
         # 1. Input
         if self.casual_3d:
             video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size

                     "motionmodule": [video_length, height, width],
                     "global_motionmodule": [video_length, height, width],
                     "selfattentiontemporal": [],
                 }[self.basic_block_type]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),

                     "motionmodule": {"num_frames":video_length, "height":height, "width":width},
                     "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                     "selfattentiontemporal": {},
                 }[self.basic_block_type]
                 hidden_states = block(
                     hidden_states,

         params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
         print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")

+        return model
+
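Both the updated `Transformer3DModel.forward` and the classes added below concatenate optional `inpaint_latents` / `control_latents` with the noisy latents along the channel axis before patch embedding, so the patch projection must be built with the widened channel count. A minimal sketch of that shape bookkeeping; the channel sizes and names below are illustrative assumptions, not the model's real configuration.

import torch
import torch.nn as nn
from einops import rearrange

latent_channels = 4        # noisy video latents
inpaint_channels = 5       # e.g. masked-video latents plus a mask channel (assumption)
patch_size, inner_dim = 2, 64

hidden_states = torch.randn(1, latent_channels, 8, 32, 32)       # (B, C, F, H, W)
inpaint_latents = torch.randn(1, inpaint_channels, 8, 32, 32)

# Same operation as `torch.concat([hidden_states, inpaint_latents], 1)` in the diff.
hidden_states = torch.cat([hidden_states, inpaint_latents], dim=1)

# The patch projection sees latent + conditioning channels together.
proj = nn.Conv2d(latent_channels + inpaint_channels, inner_dim,
                 kernel_size=patch_size, stride=patch_size)

frames = hidden_states.shape[2]
tokens = proj(rearrange(hidden_states, "b c f h w -> (b f) c h w"))
tokens = rearrange(tokens, "(b f) d h w -> b (f h w) d", f=frames)
print(tokens.shape)  # torch.Size([1, 2048, 64])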
+class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
+    """
+    HunYuanDiT: Diffusion model with a Transformer backbone.
+
+    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88):
+            The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        patch_size (`int`, *optional*):
+            The size of the patch to use for the input.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to use in feed-forward.
+        sample_size (`int`, *optional*):
+            The width of the latent images. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        cross_attention_dim (`int`, *optional*):
+            The number of dimension in the clip text embedding.
+        hidden_size (`int`, *optional*):
+            The size of hidden layer in the conditioning embedding layers.
+        num_layers (`int`, *optional*, defaults to 1):
+            The number of layers of Transformer blocks to use.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            The ratio of the hidden layer size to the input size.
+        learn_sigma (`bool`, *optional*, defaults to `True`):
+            Whether to predict variance.
+        cross_attention_dim_t5 (`int`, *optional*):
+            The number dimensions in t5 text embedding.
+        pooled_projection_dim (`int`, *optional*):
+            The size of the pooled projection.
+        text_len (`int`, *optional*):
+            The length of the clip text embedding.
+        text_len_t5 (`int`, *optional*):
+            The length of the T5 text embedding.
+    """
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        patch_size: Optional[int] = None,
+
+        n_query=16,
+        projection_dim=768,
+        activation_fn: str = "gelu-approximate",
+        sample_size=32,
+        hidden_size=1152,
+        num_layers: int = 28,
+        mlp_ratio: float = 4.0,
+        learn_sigma: bool = True,
+        cross_attention_dim: int = 1024,
+        norm_type: str = "layer_norm",
+        cross_attention_dim_t5: int = 2048,
+        pooled_projection_dim: int = 1024,
+        text_len: int = 77,
+        text_len_t5: int = 256,
+
+        # block type
+        basic_block_type: str = "basic",
+
+        time_position_encoding = False,
+        time_position_encoding_type: str = "2d_rope",
+        after_norm = False,
+        resize_inpaint_mask_directly: bool = False,
+        enable_clip_in_inpaint: bool = True,
+        enable_text_attention_mask: bool = True,
+        add_noise_in_inpaint_model: bool = False,
+    ):
+        super().__init__()
+        # 4. Define output layers
+        if learn_sigma:
+            self.out_channels = in_channels * 2 if out_channels is None else out_channels
+        else:
+            self.out_channels = in_channels if out_channels is None else out_channels
+        self.enable_inpaint = in_channels * 2 != self.out_channels if learn_sigma else in_channels != self.out_channels
+        self.num_heads = num_attention_heads
+        self.inner_dim = num_attention_heads * attention_head_dim
+        self.basic_block_type = basic_block_type
+        self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
+        self.text_embedder = PixArtAlphaTextProjection(
+            in_features=cross_attention_dim_t5,
+            hidden_size=cross_attention_dim_t5 * 4,
+            out_features=cross_attention_dim,
+            act_fn="silu_fp32",
+        )
+
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
+        )
+
+        self.pos_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            in_channels=in_channels,
+            embed_dim=hidden_size,
+            patch_size=patch_size,
+            pos_embed_type=None,
+        )
+
+        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
+            hidden_size,
+            pooled_projection_dim=pooled_projection_dim,
+            seq_len=text_len_t5,
+            cross_attention_dim=cross_attention_dim_t5,
+        )
+
+        # 3. Define transformers blocks
+        if self.basic_block_type == "hybrid_attention":
+            self.blocks = nn.ModuleList(
+                [
+                    HunyuanDiTBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        activation_fn=activation_fn,
+                        ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                        cross_attention_dim=cross_attention_dim,
+                        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                        skip=layer > num_layers // 2,
+                        after_norm=after_norm,
+                        time_position_encoding=time_position_encoding,
+                        is_local_attention=False if layer % 2 == 0 else True,
+                        local_attention_frames=2,
+                        enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
+                    )
+                    for layer in range(num_layers)
+                ]
+            )
+        elif self.basic_block_type == "kvcompression_basic":
+            self.blocks = nn.ModuleList(
+                [
+                    HunyuanDiTBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        activation_fn=activation_fn,
+                        ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                        cross_attention_dim=cross_attention_dim,
+                        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                        skip=layer > num_layers // 2,
+                        after_norm=after_norm,
+                        time_position_encoding=time_position_encoding,
+                        kvcompression=False if layer < num_layers // 2 else True,
+                        enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
+                    )
+                    for layer in range(num_layers)
+                ]
+            )
+        else:
+            self.blocks = nn.ModuleList(
+                [
+                    HunyuanDiTBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=self.config.num_attention_heads,
+                        activation_fn=activation_fn,
+                        ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                        cross_attention_dim=cross_attention_dim,
+                        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                        skip=layer > num_layers // 2,
+                        after_norm=after_norm,
+                        time_position_encoding=time_position_encoding,
+                        enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint,
+                    )
+                    for layer in range(num_layers)
+                ]
+            )
+
+        self.n_query = n_query
+        if self.enable_inpaint and enable_clip_in_inpaint:
+            self.clip_padding = nn.Parameter(
+                torch.randn((self.n_query, cross_attention_dim)) * 0.02
+            )
+            self.clip_projection = Resampler(
+                int(math.sqrt(n_query)),
+                embed_dim=cross_attention_dim,
+                num_heads=self.config.num_attention_heads,
+                kv_dim=projection_dim,
+                norm_layer=nn.LayerNorm,
+            )
+        else:
+            self.clip_padding = None
+            self.clip_projection = None
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        encoder_hidden_states=None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        image_rotary_emb=None,
+        inpaint_latents=None,
+        control_latents: torch.Tensor = None,
+        clip_encoder_hidden_states: Optional[torch.Tensor]=None,
+        clip_attention_mask: Optional[torch.Tensor]=None,
+        return_dict=True,
+    ):
+        """
+        The [`HunyuanDiT2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+                The input tensor.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step.
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+            text_embedding_mask: torch.Tensor
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+                of `BertModel`.
+            encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
+            text_embedding_mask_t5: torch.Tensor
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+                of T5 Text Encoder.
+            image_meta_size (torch.Tensor):
+                Conditional embedding indicate the image sizes
+            style: torch.Tensor:
+                Conditional embedding indicate the style
+            image_rotary_emb (`torch.Tensor`):
+                The image rotary embeddings to apply on query and key tensors during attention calculation.
+            return_dict: bool
+                Whether to return a dictionary.
+        """
+        if inpaint_latents is not None:
+            hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+        if control_latents is not None:
+            hidden_states = torch.concat([hidden_states, control_latents], 1)
+
+        # unpatchify: (N, out_channels, H, W)
+        patch_size = self.pos_embed.patch_size
+        video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // patch_size, hidden_states.shape[-1] // patch_size
+        hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
+        hidden_states = self.pos_embed(hidden_states)
+        hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        temb = self.time_extra_emb(
+            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
+        )  # [B, D]
+
+        # text projection
+        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+        encoder_hidden_states_t5 = self.text_embedder(
+            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
+
+        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
+        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
+        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+        if clip_encoder_hidden_states is not None:
+            batch_size = encoder_hidden_states.shape[0]
+
+            clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
+            clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
+
+            clip_attention_mask = clip_attention_mask.unsqueeze(2).bool()
+            clip_encoder_hidden_states = torch.where(clip_attention_mask, clip_encoder_hidden_states, self.clip_padding)
+
+        skips = []
+        for layer, block in enumerate(self.blocks):
+            if layer > self.config.num_layers // 2:
+                skip = skips.pop()
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    args = {
+                        "kvcompression_basic": [video_length, height, width, clip_encoder_hidden_states],
+                        "basic": [video_length, height, width, clip_encoder_hidden_states],
+                        "hybrid_attention": [video_length, height, width, clip_encoder_hidden_states],
+                    }[self.basic_block_type]
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        skip,
+                        *args,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    kwargs = {
+                        "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                        "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                        "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                    }[self.basic_block_type]
+                    hidden_states = block(
+                        hidden_states,
+                        temb=temb,
+                        encoder_hidden_states=encoder_hidden_states,
+                        image_rotary_emb=image_rotary_emb,
+                        skip=skip,
+                        **kwargs
+                    )  # (N, L, D)
+            else:
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    args = {
+                        "kvcompression_basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
+                        "basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
+                        "hybrid_attention": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False],
+                    }[self.basic_block_type]
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        *args,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    kwargs = {
+                        "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                        "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                        "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
+                    }[self.basic_block_type]
+                    hidden_states = block(
+                        hidden_states,
+                        temb=temb,
+                        encoder_hidden_states=encoder_hidden_states,
+                        image_rotary_emb=image_rotary_emb,
+                        disable_image_rotary_emb_in_attn1=True if layer==0 else False,
+                        **kwargs
+                    )  # (N, L, D)
+
+            if layer < (self.config.num_layers // 2 - 1):
+                skips.append(hidden_states)
+
+        # final layer
+        hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
+        hidden_states = self.proj_out(hidden_states)
+        # (N, L, patch_size ** 2 * out_channels)
+
+        hidden_states = hidden_states.reshape(
+            shape=(hidden_states.shape[0], video_length, height, width, patch_size, patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(hidden_states.shape[0], self.out_channels, video_length, height * patch_size, width * patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
+        if subfolder is not None:
+            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+        config_file = os.path.join(pretrained_model_path, 'config.json')
+        if not os.path.isfile(config_file):
+            raise RuntimeError(f"{config_file} does not exist")
+        with open(config_file, "r") as f:
+            config = json.load(f)
+
+        from diffusers.utils import WEIGHTS_NAME
+        model = cls.from_config(config, **transformer_additional_kwargs)
+        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        model_file_safetensors = model_file.replace(".bin", ".safetensors")
+        if os.path.exists(model_file_safetensors):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(model_file_safetensors)
+        else:
+            if not os.path.isfile(model_file):
+                raise RuntimeError(f"{model_file} does not exist")
+            state_dict = torch.load(model_file, map_location="cpu")
+
+        if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
+            new_shape = model.state_dict()['pos_embed.proj.weight'].size()
+            if len(new_shape) == 5:
+                state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+                state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
+            else:
+                if model.state_dict()['pos_embed.proj.weight'].size()[1] > state_dict['pos_embed.proj.weight'].size()[1]:
+                    model.state_dict()['pos_embed.proj.weight'][:, :state_dict['pos_embed.proj.weight'].size()[1], :, :] = state_dict['pos_embed.proj.weight']
+                    model.state_dict()['pos_embed.proj.weight'][:, state_dict['pos_embed.proj.weight'].size()[1]:, :, :] = 0
+                    state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
+                else:
+                    model.state_dict()['pos_embed.proj.weight'][:, :, :, :] = state_dict['pos_embed.proj.weight'][:, :model.state_dict()['pos_embed.proj.weight'].size()[1], :, :]
+                    state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
+
+        if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
+            if model.state_dict()['proj_out.weight'].size()[0] > state_dict['proj_out.weight'].size()[0]:
+                model.state_dict()['proj_out.weight'][:state_dict['proj_out.weight'].size()[0], :] = state_dict['proj_out.weight']
+                state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
+            else:
+                model.state_dict()['proj_out.weight'][:, :] = state_dict['proj_out.weight'][:model.state_dict()['proj_out.weight'].size()[0], :]
+                state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight']
+
+        if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
+            if model.state_dict()['proj_out.bias'].size()[0] > state_dict['proj_out.bias'].size()[0]:
+                model.state_dict()['proj_out.bias'][:state_dict['proj_out.bias'].size()[0]] = state_dict['proj_out.bias']
+                state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
+            else:
+                model.state_dict()['proj_out.bias'][:, :] = state_dict['proj_out.bias'][:model.state_dict()['proj_out.bias'].size()[0], :]
+                state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
+
+        tmp_state_dict = {}
+        for key in state_dict:
+            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+                tmp_state_dict[key] = state_dict[key]
+            else:
+                print(key, "Size don't match, skip")
+        state_dict = tmp_state_dict
+
+        m, u = model.load_state_dict(state_dict, strict=False)
+        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+        print(m)
+
+        params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+        print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+        print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+        return model
+
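`HunyuanTransformer3DModel.forward` above wires its blocks U-ViT style: hidden states from the first half of the layers are pushed onto `skips` and popped as the `skip` input of the mirrored layers in the second half. A toy sketch of that control flow; the blocks below are stand-ins, not `HunyuanDiTBlock`, and the skip fusion is reduced to concatenation plus a linear projection (the real block also normalizes).

import torch
import torch.nn as nn


class ToySkipBlock(nn.Module):
    def __init__(self, dim, use_skip):
        super().__init__()
        self.use_skip = use_skip
        self.skip_proj = nn.Linear(2 * dim, dim) if use_skip else None
        self.ff = nn.Linear(dim, dim)

    def forward(self, x, skip=None):
        if self.use_skip:
            # Fuse the long skip: concatenate on the feature axis, project back to dim.
            x = self.skip_proj(torch.cat([x, skip], dim=-1))
        return x + self.ff(x)


num_layers, dim = 6, 32
blocks = nn.ModuleList(
    ToySkipBlock(dim, use_skip=layer > num_layers // 2) for layer in range(num_layers)
)

x = torch.randn(2, 10, dim)
skips = []
for layer, block in enumerate(blocks):
    # Second-half layers consume the activation saved by their mirror layer.
    skip = skips.pop() if layer > num_layers // 2 else None
    x = block(x, skip=skip)
    if layer < (num_layers // 2 - 1):
        skips.append(x)
print(x.shape, len(skips))  # torch.Size([2, 10, 32]) 0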
+class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        ref_channels: int = None,
+        clip_channels: int = None,
+
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        freq_shift: int = 0,
+        num_layers: int = 30,
+        dropout: float = 0.0,
+        time_embed_dim: int = 512,
+        text_embed_dim: int = 4096,
+        text_embed_dim_t5: int = 4096,
+        norm_eps: float = 1e-5,
+
+        norm_elementwise_affine: bool = True,
+        flip_sin_to_cos: bool = True,
+
+        time_position_encoding_type: str = "3d_rope",
+        after_norm = False,
+        resize_inpaint_mask_directly: bool = False,
+        enable_clip_in_inpaint: bool = True,
+        enable_text_attention_mask: bool = True,
+        add_noise_in_inpaint_model: bool = False,
+    ):
+        super().__init__()
+        self.num_heads = num_attention_heads
+        self.inner_dim = num_attention_heads * attention_head_dim
+        self.resize_inpaint_mask_directly = resize_inpaint_mask_directly
+        self.patch_size = patch_size
+
+        post_patch_height = sample_height // patch_size
+        post_patch_width = sample_width // patch_size
+        self.post_patch_height = post_patch_height
+        self.post_patch_width = post_patch_width
+
+        self.time_proj = Timesteps(self.inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(self.inner_dim, time_embed_dim, timestep_activation_fn)
+
+        self.proj = nn.Conv2d(
+            in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
+        )
+        self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
+        self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
+
+        if ref_channels is not None:
+            self.ref_proj = nn.Conv2d(
+                ref_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
+            )
+            ref_pos_embedding = get_2d_sincos_pos_embed(self.inner_dim, (post_patch_height, post_patch_width))
+            ref_pos_embedding = torch.from_numpy(ref_pos_embedding)
+            self.register_buffer("ref_pos_embedding", ref_pos_embedding, persistent=False)
+
+        if clip_channels is not None:
+            self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                EasyAnimateDiTBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    after_norm=after_norm
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
+
+        # 5. Output blocks
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=time_embed_dim,
+            output_dim=2 * self.inner_dim,
+            norm_elementwise_affine=norm_elementwise_affine,
+            norm_eps=norm_eps,
+            chunk_dim=1,
+        )
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        self.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        timestep_cond = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        text_embedding_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states_t5: Optional[torch.Tensor] = None,
+        text_embedding_mask_t5: Optional[torch.Tensor] = None,
+        image_meta_size = None,
+        style = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        inpaint_latents: Optional[torch.Tensor] = None,
+        control_latents: Optional[torch.Tensor] = None,
+        ref_latents: Optional[torch.Tensor] = None,
+        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
+        clip_attention_mask: Optional[torch.Tensor] = None,
+        return_dict=True,
+    ):
+        batch_size, channels, video_length, height, width = hidden_states.size()
+
+        # 1. Time embedding
+        temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
+        temb = self.time_embedding(temb, timestep_cond)
+
+        # 2. Patch embedding
+        if inpaint_latents is not None:
+            hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+        if control_latents is not None:
+            hidden_states = torch.concat([hidden_states, control_latents], 1)
+
+        hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
+        hidden_states = self.proj(hidden_states)
+        hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length, h=height // self.patch_size, w=width // self.patch_size)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        if encoder_hidden_states_t5 is not None:
+            encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
+
+        if ref_latents is not None:
+            ref_batch, ref_channels, ref_video_length, ref_height, ref_width = ref_latents.shape
+            ref_latents = rearrange(ref_latents, "b c f h w ->(b f) c h w")
+            ref_latents = self.ref_proj(ref_latents)
+            ref_latents = rearrange(ref_latents, "(b f) c h w -> b c f h w", f=ref_video_length, h=ref_height // self.patch_size, w=ref_width // self.patch_size)
+            ref_latents = ref_latents.flatten(2).transpose(1, 2)
+
+            emb_size = hidden_states.size()[-1]
+            ref_pos_embedding = self.ref_pos_embedding
+            ref_pos_embedding_interpolate = ref_pos_embedding.view(1, 1, self.post_patch_height, self.post_patch_width, emb_size).permute([0, 4, 1, 2, 3])
+            ref_pos_embedding_interpolate = F.interpolate(
+                ref_pos_embedding_interpolate,
+                size=[1, height // self.config.patch_size, width // self.config.patch_size],
+                mode='trilinear', align_corners=False
+            )
+            ref_pos_embedding_interpolate = ref_pos_embedding_interpolate.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
+            ref_latents = ref_latents + ref_pos_embedding_interpolate
+
+            encoder_hidden_states = ref_latents
+
+            if clip_encoder_hidden_states is not None:
+                clip_encoder_hidden_states = self.clip_proj(clip_encoder_hidden_states)
+
+                encoder_hidden_states = torch.concat([clip_encoder_hidden_states, ref_latents], dim=1)
+
+        # 4. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        hidden_states = self.norm_final(hidden_states)
+        hidden_states = hidden_states[:, encoder_hidden_states.size()[1]:]
+
+        # 5. Final block
+        hidden_states = self.norm_out(hidden_states, temb=temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        # 6. Unpatchify
+        p = self.config.patch_size
+        output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p)
+        output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
+        if subfolder is not None:
+            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+        config_file = os.path.join(pretrained_model_path, 'config.json')
+        if not os.path.isfile(config_file):
+            raise RuntimeError(f"{config_file} does not exist")
+        with open(config_file, "r") as f:
+            config = json.load(f)
+
+        from diffusers.utils import WEIGHTS_NAME
+        model = cls.from_config(config, **transformer_additional_kwargs)
+        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        model_file_safetensors = model_file.replace(".bin", ".safetensors")
+        if os.path.exists(model_file):
+            state_dict = torch.load(model_file, map_location="cpu")
+        elif os.path.exists(model_file_safetensors):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(model_file_safetensors)
+        else:
+            from safetensors.torch import load_file, safe_open
+            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+            state_dict = {}
+            for model_file_safetensors in model_files_safetensors:
+                _state_dict = load_file(model_file_safetensors)
+                for key in _state_dict:
+                    state_dict[key] = _state_dict[key]
+
+        if model.state_dict()['proj.weight'].size() != state_dict['proj.weight'].size():
+            new_shape = model.state_dict()['proj.weight'].size()
+            if len(new_shape) == 5:
+                state_dict['proj.weight'] = state_dict['proj.weight'].unsqueeze(2).expand(new_shape).clone()
+                state_dict['proj.weight'][:, :, :-1] = 0
+            else:
+                if model.state_dict()['proj.weight'].size()[1] > state_dict['proj.weight'].size()[1]:
+                    model.state_dict()['proj.weight'][:, :state_dict['proj.weight'].size()[1], :, :] = state_dict['proj.weight']
+                    model.state_dict()['proj.weight'][:, state_dict['proj.weight'].size()[1]:, :, :] = 0
+                    state_dict['proj.weight'] = model.state_dict()['proj.weight']
+                else:
+                    model.state_dict()['proj.weight'][:, :, :, :] = state_dict['proj.weight'][:, :model.state_dict()['proj.weight'].size()[1], :, :]
+                    state_dict['proj.weight'] = model.state_dict()['proj.weight']
+
+        tmp_state_dict = {}
+        for key in state_dict:
+            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+                tmp_state_dict[key] = state_dict[key]
+            else:
+                print(key, "Size don't match, skip")
+
+        state_dict = tmp_state_dict
+
+        m, u = model.load_state_dict(state_dict, strict=False)
+        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+        print(m)
+
+        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
+        print(f"### All Parameters: {sum(params) / 1e6} M")
+
+        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+        print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+        return model
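The `# 6. Unpatchify` step at the end of `EasyAnimateTransformer3DModel.forward` folds tokens of size `patch_size * patch_size * out_channels` back into a `(B, C, F, H, W)` latent with the same reshape/permute/flatten sequence as above. A small shape check of that arithmetic with illustrative sizes:

import torch

B, C, F, H, W, p = 1, 4, 8, 32, 32, 2
tokens = torch.randn(B, F * (H // p) * (W // p), p * p * C)

out = tokens.reshape(B, F, H // p, W // p, C, p, p)
out = out.permute(0, 4, 1, 2, 5, 3, 6)   # -> (B, C, F, H//p, p, W//p, p)
out = out.flatten(5, 6).flatten(3, 4)    # -> (B, C, F, H, W)
print(out.shape)  # torch.Size([1, 4, 8, 32, 32])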
easyanimate/pipeline/pipeline_easyanimate.py
CHANGED
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import html
 import inspect
-import copy
 import re
 import urllib.parse as ul
 from dataclasses import dataclass
@@ -154,7 +154,8 @@ class EasyAnimatePipeline(DiffusionPipeline):
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
     def mask_text_embeddings(self, emb, mask):
         if emb.shape[0] == 1:
@@ -548,31 +549,13 @@ class EasyAnimatePipeline(DiffusionPipeline):
         prefix_index_before = mini_batch_encoder // 2
         prefix_index_after = mini_batch_encoder - prefix_index_before
         pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
-
-
-
-
-
-
-                with torch.no_grad():
-                    pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
-                    pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
-                    pixel_values_bs = pixel_values_bs.sample()
-                    new_pixel_values.append(pixel_values_bs)
-            latents = torch.cat(new_pixel_values, dim = 2)
-
-        if self.vae.slice_compression_vae:
-            middle_video = self.vae.decode(latents)[0]
-        else:
-            middle_video = []
-            for i in range(0, latents.shape[2], mini_batch_decoder):
-                with torch.no_grad():
-                    start_index = i
-                    end_index = i + mini_batch_decoder
-                    latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
-                    middle_video.append(latents_bs)
-            middle_video = torch.cat(middle_video, 2)
         video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
         return video

@@ -582,17 +565,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
         if self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
-
-            video = self.vae.decode(latents)[0]
-        else:
-            video = []
-            for i in range(0, latents.shape[2], mini_batch_decoder):
-                with torch.no_grad():
-                    start_index = i
-                    end_index = i + mini_batch_decoder
-                    latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
-                    video.append(latents_bs)
-            video = torch.cat(video, 2)
             video = video.clamp(-1, 1)
             video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
         else:
@@ -607,6 +580,9 @@ class EasyAnimatePipeline(DiffusionPipeline):
         video = video.cpu().float().numpy()
         return video

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -633,6 +609,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
        callback_steps: int = 1,
        clean_caption: bool = True,
        max_sequence_length: int = 120,
        **kwargs,
    ) -> Union[EasyAnimatePipelineOutput, Tuple]:
        """
@@ -780,9 +757,16 @@ class EasyAnimatePipeline(DiffusionPipeline):

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -834,6 +818,12 @@ class EasyAnimatePipeline(DiffusionPipeline):
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

        # Post-processing
        video = self.decode_latents(latents)
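The additions to this file (shown after this note) introduce `enable_autocast_float8_transformer`: before the denoising loop the transformer weights are stored as `torch.float8_e4m3fn` to save memory, and they are moved back to the original dtype on CPU once sampling finishes. A minimal sketch of the storage trick with a toy module; it assumes PyTorch >= 2.1 for the float8 dtype, and unlike the pipeline (which wraps module forwards, see `easyanimate/utils/fp8_optimization.py` in this commit) it upcasts explicitly before the matmul.

import torch
import torch.nn as nn

transformer = nn.Linear(64, 64, bias=False).to(torch.bfloat16)
origin_weight_dtype = transformer.weight.dtype

# Cast for storage: roughly halves weight memory versus bf16.
transformer = transformer.to(torch.float8_e4m3fn)

x = torch.randn(2, 64, dtype=torch.bfloat16)
for _ in range(3):  # stand-in for the denoising loop
    w = transformer.weight.to(torch.bfloat16)  # upcast just-in-time for the matmul
    x = x @ w.t()

# Restore the original dtype once sampling is done, as the pipeline does.
transformer = transformer.to(origin_weight_dtype)
print(x.dtype, transformer.weight.dtype)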
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import html
 import inspect
 import re
 import urllib.parse as ul
 from dataclasses import dataclass

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.enable_autocast_float8_transformer_flag = False
+
     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
     def mask_text_embeddings(self, emb, mask):
         if emb.shape[0] == 1:

         prefix_index_before = mini_batch_encoder // 2
         prefix_index_after = mini_batch_encoder - prefix_index_before
         pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
+
+        # Encode middle videos
+        latents = self.vae.encode(pixel_values)[0]
+        latents = latents.mode()
+        # Decode middle videos
+        middle_video = self.vae.decode(latents)[0]
+
         video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
         return video

         if self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
+            video = self.vae.decode(latents)[0]
             video = video.clamp(-1, 1)
             video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
         else:

         video = video.cpu().float().numpy()
         return video

+    def enable_autocast_float8_transformer(self):
+        self.enable_autocast_float8_transformer_flag = True
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

         callback_steps: int = 1,
         clean_caption: bool = True,
         max_sequence_length: int = 120,
+        comfyui_progressbar: bool = False,
         **kwargs,
     ) -> Union[EasyAnimatePipelineOutput, Tuple]:
         """

         added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

+        torch.cuda.empty_cache()
+        if self.enable_autocast_float8_transformer_flag:
+            origin_weight_dtype = self.transformer.dtype
+            self.transformer = self.transformer.to(torch.float8_e4m3fn)
+
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        if comfyui_progressbar:
+            from comfy.utils import ProgressBar
+            pbar = ProgressBar(num_inference_steps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if comfyui_progressbar:
+                pbar.update(1)
+
+        if self.enable_autocast_float8_transformer_flag:
+            self.transformer = self.transformer.to("cpu", origin_weight_dtype)
+
         # Post-processing
         video = self.decode_latents(latents)
easyanimate/pipeline/pipeline_easyanimate_inpaint.py
CHANGED
@@ -12,14 +12,13 @@
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
|
|
|
|
15 |
import html
|
16 |
import inspect
|
17 |
import re
|
18 |
-
import gc
|
19 |
-
import copy
|
20 |
import urllib.parse as ul
|
21 |
from dataclasses import dataclass
|
22 |
-
from PIL import Image
|
23 |
from typing import Callable, List, Optional, Tuple, Union
|
24 |
|
25 |
import numpy as np
|
@@ -34,9 +33,10 @@ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
|
34 |
replace_example_docstring)
|
35 |
from diffusers.utils.torch_utils import randn_tensor
|
36 |
from einops import rearrange
|
|
|
37 |
from tqdm import tqdm
|
38 |
-
from transformers import
|
39 |
-
|
40 |
|
41 |
from ..models.transformer3d import Transformer3DModel
|
42 |
|
@@ -129,6 +129,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
129 |
self.mask_processor = VaeImageProcessor(
|
130 |
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
131 |
)
|
|
|
132 |
|
133 |
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
134 |
def mask_text_embeddings(self, emb, mask):
|
@@ -493,6 +494,60 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
493 |
|
494 |
return caption.strip()
|
495 |
|
496 |
def prepare_latents(
|
497 |
self,
|
498 |
batch_size,
|
@@ -529,22 +584,11 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
529 |
bs = 1
|
530 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
531 |
new_video = []
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
new_video.append(video_bs)
|
538 |
-
else:
|
539 |
-
for i in range(0, video.shape[0], bs):
|
540 |
-
new_video_mini_batch = []
|
541 |
-
for j in range(0, video.shape[2], mini_batch_encoder):
|
542 |
-
video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
543 |
-
video_bs = self.vae.encode(video_bs)[0]
|
544 |
-
video_bs = video_bs.sample()
|
545 |
-
new_video_mini_batch.append(video_bs)
|
546 |
-
new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2)
|
547 |
-
new_video.append(new_video_mini_batch)
|
548 |
video = torch.cat(new_video, dim = 0)
|
549 |
video = video * self.vae.config.scaling_factor
|
550 |
|
@@ -585,31 +629,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
585 |
prefix_index_before = mini_batch_encoder // 2
|
586 |
prefix_index_after = mini_batch_encoder - prefix_index_before
|
587 |
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
with torch.no_grad():
|
596 |
-
pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
|
597 |
-
pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
|
598 |
-
pixel_values_bs = pixel_values_bs.sample()
|
599 |
-
new_pixel_values.append(pixel_values_bs)
|
600 |
-
latents = torch.cat(new_pixel_values, dim = 2)
|
601 |
-
|
602 |
-
if self.vae.slice_compression_vae:
|
603 |
-
middle_video = self.vae.decode(latents)[0]
|
604 |
-
else:
|
605 |
-
middle_video = []
|
606 |
-
for i in range(0, latents.shape[2], mini_batch_decoder):
|
607 |
-
with torch.no_grad():
|
608 |
-
start_index = i
|
609 |
-
end_index = i + mini_batch_decoder
|
610 |
-
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
611 |
-
middle_video.append(latents_bs)
|
612 |
-
middle_video = torch.cat(middle_video, 2)
|
613 |
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
614 |
return video
|
615 |
|
@@ -619,17 +645,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
619 |
if self.vae.quant_conv.weight.ndim==5:
|
620 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
621 |
mini_batch_decoder = self.vae.mini_batch_decoder
|
622 |
-
|
623 |
-
video = self.vae.decode(latents)[0]
|
624 |
-
else:
|
625 |
-
video = []
|
626 |
-
for i in range(0, latents.shape[2], mini_batch_decoder):
|
627 |
-
with torch.no_grad():
|
628 |
-
start_index = i
|
629 |
-
end_index = i + mini_batch_decoder
|
630 |
-
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
631 |
-
video.append(latents_bs)
|
632 |
-
video = torch.cat(video, 2)
|
633 |
video = video.clamp(-1, 1)
|
634 |
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
635 |
else:
|
@@ -668,84 +684,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
668 |
|
669 |
return timesteps, num_inference_steps - t_start
|
670 |
|
671 |
-
def prepare_mask_latents(
|
672 |
-
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
673 |
-
):
|
674 |
-
# resize the mask to latents shape as we concatenate the mask to the latents
|
675 |
-
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
676 |
-
# and half precision
|
677 |
-
video_length = mask.shape[2]
|
678 |
-
|
679 |
-
mask = mask.to(device=device, dtype=self.vae.dtype)
|
680 |
-
if self.vae.quant_conv.weight.ndim==5:
|
681 |
-
bs = 1
|
682 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
683 |
-
new_mask = []
|
684 |
-
if self.vae.slice_compression_vae:
|
685 |
-
for i in range(0, mask.shape[0], bs):
|
686 |
-
mask_bs = mask[i : i + bs]
|
687 |
-
mask_bs = self.vae.encode(mask_bs)[0]
|
688 |
-
mask_bs = mask_bs.sample()
|
689 |
-
new_mask.append(mask_bs)
|
690 |
-
else:
|
691 |
-
for i in range(0, mask.shape[0], bs):
|
692 |
-
new_mask_mini_batch = []
|
693 |
-
for j in range(0, mask.shape[2], mini_batch_encoder):
|
694 |
-
mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
695 |
-
mask_bs = self.vae.encode(mask_bs)[0]
|
696 |
-
mask_bs = mask_bs.sample()
|
697 |
-
new_mask_mini_batch.append(mask_bs)
|
698 |
-
new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
|
699 |
-
new_mask.append(new_mask_mini_batch)
|
700 |
-
mask = torch.cat(new_mask, dim = 0)
|
701 |
-
mask = mask * self.vae.config.scaling_factor
|
702 |
-
|
703 |
-
else:
|
704 |
-
if mask.shape[1] == 4:
|
705 |
-
mask = mask
|
706 |
-
else:
|
707 |
-
video_length = mask.shape[2]
|
708 |
-
mask = rearrange(mask, "b c f h w -> (b f) c h w")
|
709 |
-
mask = self._encode_vae_image(mask, generator=generator)
|
710 |
-
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
|
711 |
-
|
712 |
-
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
713 |
-
if self.vae.quant_conv.weight.ndim==5:
|
714 |
-
bs = 1
|
715 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
716 |
-
new_mask_pixel_values = []
|
717 |
-
if self.vae.slice_compression_vae:
|
718 |
-
for i in range(0, masked_image.shape[0], bs):
|
719 |
-
mask_pixel_values_bs = masked_image[i : i + bs]
|
720 |
-
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
721 |
-
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
722 |
-
new_mask_pixel_values.append(mask_pixel_values_bs)
|
723 |
-
else:
|
724 |
-
for i in range(0, masked_image.shape[0], bs):
|
725 |
-
new_mask_pixel_values_mini_batch = []
|
726 |
-
for j in range(0, masked_image.shape[2], mini_batch_encoder):
|
727 |
-
mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
728 |
-
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
729 |
-
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
730 |
-
new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
|
731 |
-
new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
|
732 |
-
new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
|
733 |
-
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
734 |
-
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
735 |
-
|
736 |
-
else:
|
737 |
-
if masked_image.shape[1] == 4:
|
738 |
-
masked_image_latents = masked_image
|
739 |
-
else:
|
740 |
-
video_length = mask.shape[2]
|
741 |
-
masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
|
742 |
-
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
743 |
-
masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
|
744 |
|
745 |
-
# aligning device to prevent device errors when concating it with the latent model input
|
746 |
-
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
747 |
-
return mask, masked_image_latents
|
748 |
-
|
749 |
@torch.no_grad()
|
750 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
751 |
def __call__(
|
@@ -779,6 +720,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
779 |
max_sequence_length: int = 120,
|
780 |
clip_image: Image = None,
|
781 |
clip_apply_ratio: float = 0.50,
|
|
|
|
|
782 |
) -> Union[EasyAnimatePipelineOutput, Tuple]:
|
783 |
"""
|
784 |
Function invoked when calling the pipeline for generation.
|
@@ -1057,10 +1000,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
1057 |
gc.collect()
|
1058 |
torch.cuda.empty_cache()
|
1059 |
torch.cuda.ipc_collect()
|
|
1060 |
|
1061 |
# 10. Denoising loop
|
1062 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1063 |
self._num_timesteps = len(timesteps)
|
1064 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1065 |
for i, t in enumerate(timesteps):
|
1066 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
@@ -1130,16 +1079,19 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
1130 |
step_idx = i // getattr(self.scheduler, "order", 1)
|
1131 |
callback(step_idx, t, latents)
|
1132 |
|
1133 |
gc.collect()
|
1134 |
torch.cuda.empty_cache()
|
1135 |
torch.cuda.ipc_collect()
|
1136 |
|
1137 |
# Post-processing
|
1138 |
video = self.decode_latents(latents)
|
1139 |
-
|
1140 |
-
gc.collect()
|
1141 |
-
torch.cuda.empty_cache()
|
1142 |
-
torch.cuda.ipc_collect()
|
1143 |
# Convert to tensor
|
1144 |
if output_type == "latent":
|
1145 |
video = torch.from_numpy(video)
|
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
15 |
+
import copy
|
16 |
+
import gc
|
17 |
import html
|
18 |
import inspect
|
19 |
import re
|
|
|
|
|
20 |
import urllib.parse as ul
|
21 |
from dataclasses import dataclass
|
|
|
22 |
from typing import Callable, List, Optional, Tuple, Union
|
23 |
|
24 |
import numpy as np
|
|
|
33 |
replace_example_docstring)
|
34 |
from diffusers.utils.torch_utils import randn_tensor
|
35 |
from einops import rearrange
|
36 |
+
from PIL import Image
|
37 |
from tqdm import tqdm
|
38 |
+
from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
|
39 |
+
T5EncoderModel, T5Tokenizer)
|
40 |
|
41 |
from ..models.transformer3d import Transformer3DModel
|
42 |
|
|
|
129 |
self.mask_processor = VaeImageProcessor(
|
130 |
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
131 |
)
|
132 |
+
self.enable_autocast_float8_transformer_flag = False
|
133 |
|
134 |
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
135 |
def mask_text_embeddings(self, emb, mask):
|
|
|
494 |
|
495 |
return caption.strip()
|
496 |
|
497 |
+
def prepare_mask_latents(
|
498 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
499 |
+
):
|
500 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
501 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
502 |
+
# and half precision
|
503 |
+
video_length = mask.shape[2]
|
504 |
+
|
505 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
506 |
+
if self.vae.quant_conv.weight.ndim==5:
|
507 |
+
bs = 1
|
508 |
+
new_mask = []
|
509 |
+
for i in range(0, mask.shape[0], bs):
|
510 |
+
mask_bs = mask[i : i + bs]
|
511 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
512 |
+
mask_bs = mask_bs.sample()
|
513 |
+
new_mask.append(mask_bs)
|
514 |
+
mask = torch.cat(new_mask, dim = 0)
|
515 |
+
mask = mask * self.vae.config.scaling_factor
|
516 |
+
|
517 |
+
else:
|
518 |
+
if mask.shape[1] == 4:
|
519 |
+
mask = mask
|
520 |
+
else:
|
521 |
+
video_length = mask.shape[2]
|
522 |
+
mask = rearrange(mask, "b c f h w -> (b f) c h w")
|
523 |
+
mask = self._encode_vae_image(mask, generator=generator)
|
524 |
+
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
|
525 |
+
|
526 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
527 |
+
if self.vae.quant_conv.weight.ndim==5:
|
528 |
+
bs = 1
|
529 |
+
new_mask_pixel_values = []
|
530 |
+
for i in range(0, masked_image.shape[0], bs):
|
531 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
532 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
533 |
+
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
534 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
535 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
536 |
+
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
537 |
+
|
538 |
+
else:
|
539 |
+
if masked_image.shape[1] == 4:
|
540 |
+
masked_image_latents = masked_image
|
541 |
+
else:
|
542 |
+
video_length = mask.shape[2]
|
543 |
+
masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
|
544 |
+
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
545 |
+
masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
|
546 |
+
|
547 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
548 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
549 |
+
return mask, masked_image_latents
|
550 |
+
|
551 |
def prepare_latents(
|
552 |
self,
|
553 |
batch_size,
|
|
|
584 |
bs = 1
|
585 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
586 |
new_video = []
|
587 |
+
for i in range(0, video.shape[0], bs):
|
588 |
+
video_bs = video[i : i + bs]
|
589 |
+
video_bs = self.vae.encode(video_bs)[0]
|
590 |
+
video_bs = video_bs.sample()
|
591 |
+
new_video.append(video_bs)
|
|
592 |
video = torch.cat(new_video, dim = 0)
|
593 |
video = video * self.vae.config.scaling_factor
|
594 |
|
|
|
629 |
prefix_index_before = mini_batch_encoder // 2
|
630 |
prefix_index_after = mini_batch_encoder - prefix_index_before
|
631 |
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
632 |
+
|
633 |
+
# Encode middle videos
|
634 |
+
latents = self.vae.encode(pixel_values)[0]
|
635 |
+
latents = latents.sample()
|
636 |
+
# Decode middle videos
|
637 |
+
middle_video = self.vae.decode(latents)[0]
|
638 |
+
|
639 |
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
640 |
return video
|
641 |
|
|
|
645 |
if self.vae.quant_conv.weight.ndim==5:
|
646 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
647 |
mini_batch_decoder = self.vae.mini_batch_decoder
|
648 |
+
video = self.vae.decode(latents)[0]
|
649 |
video = video.clamp(-1, 1)
|
650 |
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
651 |
else:
|
|
|
684 |
|
685 |
return timesteps, num_inference_steps - t_start
|
686 |
|
687 |
+
def enable_autocast_float8_transformer(self):
|
688 |
+
self.enable_autocast_float8_transformer_flag = True
|
689 |
|
690 |
@torch.no_grad()
|
691 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
692 |
def __call__(
|
|
|
720 |
max_sequence_length: int = 120,
|
721 |
clip_image: Image = None,
|
722 |
clip_apply_ratio: float = 0.50,
|
723 |
+
comfyui_progressbar: bool = False,
|
724 |
+
**kwargs,
|
725 |
) -> Union[EasyAnimatePipelineOutput, Tuple]:
|
726 |
"""
|
727 |
Function invoked when calling the pipeline for generation.
|
|
|
1000 |
gc.collect()
|
1001 |
torch.cuda.empty_cache()
|
1002 |
torch.cuda.ipc_collect()
|
1003 |
+
if self.enable_autocast_float8_transformer_flag:
|
1004 |
+
origin_weight_dtype = self.transformer.dtype
|
1005 |
+
self.transformer = self.transformer.to(torch.float8_e4m3fn)
|
1006 |
|
1007 |
# 10. Denoising loop
|
1008 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1009 |
self._num_timesteps = len(timesteps)
|
1010 |
+
if comfyui_progressbar:
|
1011 |
+
from comfy.utils import ProgressBar
|
1012 |
+
pbar = ProgressBar(num_inference_steps)
|
1013 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1014 |
for i, t in enumerate(timesteps):
|
1015 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
|
1079 |
step_idx = i // getattr(self.scheduler, "order", 1)
|
1080 |
callback(step_idx, t, latents)
|
1081 |
|
1082 |
+
if comfyui_progressbar:
|
1083 |
+
pbar.update(1)
|
1084 |
+
|
1085 |
+
if self.enable_autocast_float8_transformer_flag:
|
1086 |
+
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
1087 |
+
|
1088 |
gc.collect()
|
1089 |
torch.cuda.empty_cache()
|
1090 |
torch.cuda.ipc_collect()
|
1091 |
|
1092 |
# Post-processing
|
1093 |
video = self.decode_latents(latents)
|
1094 |
+
|
1095 |
# Convert to tensor
|
1096 |
if output_type == "latent":
|
1097 |
video = torch.from_numpy(video)
|
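For reference, the simplified `smooth_output` helper that both pipelines now share re-encodes and re-decodes the non-boundary frames once more and averages the two reconstructions, which softens seams between VAE mini-batches. Below is a minimal sketch of that idea, assuming a VAE whose `encode()` returns a distribution-like object with `.sample()` and whose `decode()` returns the reconstruction first, as in the diffs above; `smooth_output_sketch` is an illustrative name, not the pipeline method itself.

```py
import torch

@torch.no_grad()
def smooth_output_sketch(vae, video, mini_batch_encoder):
    # video: (batch, channels, frames, height, width), already decoded once.
    if video.size(2) <= mini_batch_encoder:
        return video
    before = mini_batch_encoder // 2
    after = mini_batch_encoder - before
    middle = video[:, :, before:-after]
    # Second encode/decode pass over the middle frames only.
    latents = vae.encode(middle)[0].sample()
    redecoded = vae.decode(latents)[0]
    # Blend the two reconstructions to hide mini-batch seams.
    video[:, :, before:-after] = (middle + redecoded) / 2
    return video
```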
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py
ADDED
@@ -0,0 +1,925 @@
|
1 |
+
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
|
23 |
+
get_3d_rotary_pos_embed)
|
24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
25 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
26 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
27 |
+
StableDiffusionSafetyChecker
|
28 |
+
from diffusers.schedulers import DDIMScheduler
|
29 |
+
from diffusers.utils import (is_torch_xla_available, logging,
|
30 |
+
replace_example_docstring)
|
31 |
+
from diffusers.utils.torch_utils import randn_tensor
|
32 |
+
from einops import rearrange
|
33 |
+
from tqdm import tqdm
|
34 |
+
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
35 |
+
T5Tokenizer, T5EncoderModel)
|
36 |
+
|
37 |
+
from .pipeline_easyanimate import EasyAnimatePipelineOutput
|
38 |
+
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
39 |
+
|
40 |
+
if is_torch_xla_available():
|
41 |
+
import torch_xla.core.xla_model as xm
|
42 |
+
|
43 |
+
XLA_AVAILABLE = True
|
44 |
+
else:
|
45 |
+
XLA_AVAILABLE = False
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
+
|
50 |
+
EXAMPLE_DOC_STRING = """
|
51 |
+
Examples:
|
52 |
+
```py
|
53 |
+
>>> pass
|
54 |
+
```
|
55 |
+
"""
|
56 |
+
|
57 |
+
|
58 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
59 |
+
tw = tgt_width
|
60 |
+
th = tgt_height
|
61 |
+
h, w = src
|
62 |
+
r = h / w
|
63 |
+
if r > (th / tw):
|
64 |
+
resize_height = th
|
65 |
+
resize_width = int(round(th / h * w))
|
66 |
+
else:
|
67 |
+
resize_width = tw
|
68 |
+
resize_height = int(round(tw / w * h))
|
69 |
+
|
70 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
71 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
72 |
+
|
73 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
74 |
+
|
75 |
+
|
76 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
77 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
78 |
+
"""
|
79 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
80 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
81 |
+
"""
|
82 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
83 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
84 |
+
# rescale the results from guidance (fixes overexposure)
|
85 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
86 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
87 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
88 |
+
return noise_cfg
|
89 |
+
|
90 |
+
|
91 |
+
class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline):
|
92 |
+
r"""
|
93 |
+
Pipeline for text-to-video generation using EasyAnimate.
|
94 |
+
|
95 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
96 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
97 |
+
|
98 |
+
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
99 |
+
HunyuanDiT team)
|
100 |
+
|
101 |
+
Args:
|
102 |
+
vae ([`AutoencoderKLMagvit`]):
|
103 |
+
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
104 |
+
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
|
105 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
106 |
+
EasyAnimate uses a fine-tuned [bilingual CLIP].
|
107 |
+
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
|
108 |
+
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
|
109 |
+
transformer ([`EasyAnimateTransformer3DModel`]):
|
110 |
+
The EasyAnimate model designed by Tencent Hunyuan.
|
111 |
+
text_encoder_2 (`T5EncoderModel`):
|
112 |
+
The mT5 embedder.
|
113 |
+
tokenizer_2 (`T5Tokenizer`):
|
114 |
+
The tokenizer for the mT5 embedder.
|
115 |
+
scheduler ([`DDIMScheduler`]):
|
116 |
+
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
117 |
+
"""
|
118 |
+
|
119 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
120 |
+
_optional_components = [
|
121 |
+
"safety_checker",
|
122 |
+
"feature_extractor",
|
123 |
+
"text_encoder_2",
|
124 |
+
"tokenizer_2",
|
125 |
+
"text_encoder",
|
126 |
+
"tokenizer",
|
127 |
+
]
|
128 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
129 |
+
_callback_tensor_inputs = [
|
130 |
+
"latents",
|
131 |
+
"prompt_embeds",
|
132 |
+
"negative_prompt_embeds",
|
133 |
+
"prompt_embeds_2",
|
134 |
+
"negative_prompt_embeds_2",
|
135 |
+
]
|
136 |
+
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
vae: AutoencoderKLMagvit,
|
140 |
+
text_encoder: BertModel,
|
141 |
+
tokenizer: BertTokenizer,
|
142 |
+
text_encoder_2: T5EncoderModel,
|
143 |
+
tokenizer_2: T5Tokenizer,
|
144 |
+
transformer: EasyAnimateTransformer3DModel,
|
145 |
+
scheduler: DDIMScheduler,
|
146 |
+
safety_checker: StableDiffusionSafetyChecker,
|
147 |
+
feature_extractor: CLIPImageProcessor,
|
148 |
+
requires_safety_checker: bool = True,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
self.register_modules(
|
153 |
+
vae=vae,
|
154 |
+
text_encoder=text_encoder,
|
155 |
+
tokenizer=tokenizer,
|
156 |
+
tokenizer_2=tokenizer_2,
|
157 |
+
transformer=transformer,
|
158 |
+
scheduler=scheduler,
|
159 |
+
safety_checker=safety_checker,
|
160 |
+
feature_extractor=feature_extractor,
|
161 |
+
text_encoder_2=text_encoder_2,
|
162 |
+
)
|
163 |
+
|
164 |
+
if safety_checker is None and requires_safety_checker:
|
165 |
+
logger.warning(
|
166 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
167 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
168 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
169 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
170 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
171 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
172 |
+
)
|
173 |
+
|
174 |
+
if safety_checker is not None and feature_extractor is None:
|
175 |
+
raise ValueError(
|
176 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
177 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
178 |
+
)
|
179 |
+
|
180 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
181 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
182 |
+
self.enable_autocast_float8_transformer_flag = False
|
183 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
184 |
+
|
185 |
+
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
186 |
+
super().enable_sequential_cpu_offload(*args, **kwargs)
|
187 |
+
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
|
188 |
+
import accelerate
|
189 |
+
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
|
190 |
+
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
|
191 |
+
|
192 |
+
def encode_prompt(
|
193 |
+
self,
|
194 |
+
prompt: str,
|
195 |
+
device: torch.device,
|
196 |
+
dtype: torch.dtype,
|
197 |
+
num_images_per_prompt: int = 1,
|
198 |
+
do_classifier_free_guidance: bool = True,
|
199 |
+
negative_prompt: Optional[str] = None,
|
200 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
201 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
202 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
203 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
204 |
+
max_sequence_length: Optional[int] = None,
|
205 |
+
text_encoder_index: int = 0,
|
206 |
+
actual_max_sequence_length: int = 256
|
207 |
+
):
|
208 |
+
r"""
|
209 |
+
Encodes the prompt into text encoder hidden states.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
prompt (`str` or `List[str]`, *optional*):
|
213 |
+
prompt to be encoded
|
214 |
+
device: (`torch.device`):
|
215 |
+
torch device
|
216 |
+
dtype (`torch.dtype`):
|
217 |
+
torch dtype
|
218 |
+
num_images_per_prompt (`int`):
|
219 |
+
number of images that should be generated per prompt
|
220 |
+
do_classifier_free_guidance (`bool`):
|
221 |
+
whether to use classifier free guidance or not
|
222 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
223 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
224 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
225 |
+
less than `1`).
|
226 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
227 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
228 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
229 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
230 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
231 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
232 |
+
argument.
|
233 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
234 |
+
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
235 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
236 |
+
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
237 |
+
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
238 |
+
text_encoder_index (`int`, *optional*):
|
239 |
+
Index of the text encoder to use. `0` for clip and `1` for T5.
|
240 |
+
"""
|
241 |
+
tokenizers = [self.tokenizer, self.tokenizer_2]
|
242 |
+
text_encoders = [self.text_encoder, self.text_encoder_2]
|
243 |
+
|
244 |
+
tokenizer = tokenizers[text_encoder_index]
|
245 |
+
text_encoder = text_encoders[text_encoder_index]
|
246 |
+
|
247 |
+
if max_sequence_length is None:
|
248 |
+
if text_encoder_index == 0:
|
249 |
+
max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
|
250 |
+
if text_encoder_index == 1:
|
251 |
+
max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
|
252 |
+
else:
|
253 |
+
max_length = max_sequence_length
|
254 |
+
|
255 |
+
if prompt is not None and isinstance(prompt, str):
|
256 |
+
batch_size = 1
|
257 |
+
elif prompt is not None and isinstance(prompt, list):
|
258 |
+
batch_size = len(prompt)
|
259 |
+
else:
|
260 |
+
batch_size = prompt_embeds.shape[0]
|
261 |
+
|
262 |
+
if prompt_embeds is None:
|
263 |
+
text_inputs = tokenizer(
|
264 |
+
prompt,
|
265 |
+
padding="max_length",
|
266 |
+
max_length=max_length,
|
267 |
+
truncation=True,
|
268 |
+
return_attention_mask=True,
|
269 |
+
return_tensors="pt",
|
270 |
+
)
|
271 |
+
text_input_ids = text_inputs.input_ids
|
272 |
+
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
273 |
+
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
274 |
+
text_inputs = tokenizer(
|
275 |
+
reprompt,
|
276 |
+
padding="max_length",
|
277 |
+
max_length=max_length,
|
278 |
+
truncation=True,
|
279 |
+
return_attention_mask=True,
|
280 |
+
return_tensors="pt",
|
281 |
+
)
|
282 |
+
text_input_ids = text_inputs.input_ids
|
283 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
284 |
+
|
285 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
286 |
+
text_input_ids, untruncated_ids
|
287 |
+
):
|
288 |
+
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
289 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
290 |
+
logger.warning(
|
291 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
292 |
+
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
293 |
+
)
|
294 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
295 |
+
|
296 |
+
if self.transformer.config.enable_text_attention_mask:
|
297 |
+
prompt_embeds = text_encoder(
|
298 |
+
text_input_ids.to(device),
|
299 |
+
attention_mask=prompt_attention_mask,
|
300 |
+
)
|
301 |
+
else:
|
302 |
+
prompt_embeds = text_encoder(
|
303 |
+
text_input_ids.to(device)
|
304 |
+
)
|
305 |
+
prompt_embeds = prompt_embeds[0]
|
306 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
307 |
+
|
308 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
309 |
+
|
310 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
311 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
312 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
313 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
314 |
+
|
315 |
+
# get unconditional embeddings for classifier free guidance
|
316 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
317 |
+
uncond_tokens: List[str]
|
318 |
+
if negative_prompt is None:
|
319 |
+
uncond_tokens = [""] * batch_size
|
320 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
321 |
+
raise TypeError(
|
322 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
323 |
+
f" {type(prompt)}."
|
324 |
+
)
|
325 |
+
elif isinstance(negative_prompt, str):
|
326 |
+
uncond_tokens = [negative_prompt]
|
327 |
+
elif batch_size != len(negative_prompt):
|
328 |
+
raise ValueError(
|
329 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
330 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
331 |
+
" the batch size of `prompt`."
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
uncond_tokens = negative_prompt
|
335 |
+
|
336 |
+
max_length = prompt_embeds.shape[1]
|
337 |
+
uncond_input = tokenizer(
|
338 |
+
uncond_tokens,
|
339 |
+
padding="max_length",
|
340 |
+
max_length=max_length,
|
341 |
+
truncation=True,
|
342 |
+
return_tensors="pt",
|
343 |
+
)
|
344 |
+
uncond_input_ids = uncond_input.input_ids
|
345 |
+
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
346 |
+
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
347 |
+
uncond_input = tokenizer(
|
348 |
+
reuncond_tokens,
|
349 |
+
padding="max_length",
|
350 |
+
max_length=max_length,
|
351 |
+
truncation=True,
|
352 |
+
return_attention_mask=True,
|
353 |
+
return_tensors="pt",
|
354 |
+
)
|
355 |
+
uncond_input_ids = uncond_input.input_ids
|
356 |
+
|
357 |
+
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
358 |
+
if self.transformer.config.enable_text_attention_mask:
|
359 |
+
negative_prompt_embeds = text_encoder(
|
360 |
+
uncond_input.input_ids.to(device),
|
361 |
+
attention_mask=negative_prompt_attention_mask,
|
362 |
+
)
|
363 |
+
else:
|
364 |
+
negative_prompt_embeds = text_encoder(
|
365 |
+
uncond_input.input_ids.to(device)
|
366 |
+
)
|
367 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
368 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
369 |
+
|
370 |
+
if do_classifier_free_guidance:
|
371 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
372 |
+
seq_len = negative_prompt_embeds.shape[1]
|
373 |
+
|
374 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
375 |
+
|
376 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
377 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
378 |
+
|
379 |
+
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
380 |
+
|
381 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
382 |
+
def run_safety_checker(self, image, device, dtype):
|
383 |
+
if self.safety_checker is None:
|
384 |
+
has_nsfw_concept = None
|
385 |
+
else:
|
386 |
+
if torch.is_tensor(image):
|
387 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
388 |
+
else:
|
389 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
390 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
391 |
+
image, has_nsfw_concept = self.safety_checker(
|
392 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
393 |
+
)
|
394 |
+
return image, has_nsfw_concept
|
395 |
+
|
396 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
397 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
398 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
399 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
400 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
401 |
+
# and should be between [0, 1]
|
402 |
+
|
403 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
404 |
+
extra_step_kwargs = {}
|
405 |
+
if accepts_eta:
|
406 |
+
extra_step_kwargs["eta"] = eta
|
407 |
+
|
408 |
+
# check if the scheduler accepts generator
|
409 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
410 |
+
if accepts_generator:
|
411 |
+
extra_step_kwargs["generator"] = generator
|
412 |
+
return extra_step_kwargs
|
413 |
+
|
414 |
+
def check_inputs(
|
415 |
+
self,
|
416 |
+
prompt,
|
417 |
+
height,
|
418 |
+
width,
|
419 |
+
negative_prompt=None,
|
420 |
+
prompt_embeds=None,
|
421 |
+
negative_prompt_embeds=None,
|
422 |
+
prompt_attention_mask=None,
|
423 |
+
negative_prompt_attention_mask=None,
|
424 |
+
prompt_embeds_2=None,
|
425 |
+
negative_prompt_embeds_2=None,
|
426 |
+
prompt_attention_mask_2=None,
|
427 |
+
negative_prompt_attention_mask_2=None,
|
428 |
+
callback_on_step_end_tensor_inputs=None,
|
429 |
+
):
|
430 |
+
if height % 8 != 0 or width % 8 != 0:
|
431 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
432 |
+
|
433 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
434 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
435 |
+
):
|
436 |
+
raise ValueError(
|
437 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
438 |
+
)
|
439 |
+
|
440 |
+
if prompt is not None and prompt_embeds is not None:
|
441 |
+
raise ValueError(
|
442 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
443 |
+
" only forward one of the two."
|
444 |
+
)
|
445 |
+
elif prompt is None and prompt_embeds is None:
|
446 |
+
raise ValueError(
|
447 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
448 |
+
)
|
449 |
+
elif prompt is None and prompt_embeds_2 is None:
|
450 |
+
raise ValueError(
|
451 |
+
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
452 |
+
)
|
453 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
454 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
455 |
+
|
456 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
457 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
458 |
+
|
459 |
+
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
|
460 |
+
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
|
461 |
+
|
462 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
463 |
+
raise ValueError(
|
464 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
465 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
466 |
+
)
|
467 |
+
|
468 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
469 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
470 |
+
|
471 |
+
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
|
472 |
+
raise ValueError(
|
473 |
+
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
|
474 |
+
)
|
475 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
476 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
477 |
+
raise ValueError(
|
478 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
479 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
480 |
+
f" {negative_prompt_embeds.shape}."
|
481 |
+
)
|
482 |
+
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
|
483 |
+
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
|
484 |
+
raise ValueError(
|
485 |
+
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
|
486 |
+
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
|
487 |
+
f" {negative_prompt_embeds_2.shape}."
|
488 |
+
)
|
489 |
+
|
490 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
491 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
492 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
493 |
+
if self.vae.cache_mag_vae:
|
494 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
495 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
496 |
+
shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
497 |
+
else:
|
498 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
499 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
500 |
+
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
501 |
+
else:
|
502 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
503 |
+
|
504 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
505 |
+
raise ValueError(
|
506 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
507 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
508 |
+
)
|
509 |
+
|
510 |
+
if latents is None:
|
511 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
512 |
+
else:
|
513 |
+
latents = latents.to(device)
|
514 |
+
|
515 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
516 |
+
latents = latents * self.scheduler.init_noise_sigma
|
517 |
+
return latents
|
518 |
+
|
519 |
+
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
520 |
+
if video.size()[2] <= mini_batch_encoder:
|
521 |
+
return video
|
522 |
+
prefix_index_before = mini_batch_encoder // 2
|
523 |
+
prefix_index_after = mini_batch_encoder - prefix_index_before
|
524 |
+
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
525 |
+
|
526 |
+
# Encode middle videos
|
527 |
+
latents = self.vae.encode(pixel_values)[0]
|
528 |
+
latents = latents.mode()
|
529 |
+
# Decode middle videos
|
530 |
+
middle_video = self.vae.decode(latents)[0]
|
531 |
+
|
532 |
+
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
533 |
+
return video
|
534 |
+
|
535 |
+
def decode_latents(self, latents):
|
536 |
+
video_length = latents.shape[2]
|
537 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
538 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
539 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
540 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
541 |
+
video = self.vae.decode(latents)[0]
|
542 |
+
video = video.clamp(-1, 1)
|
543 |
+
if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
|
544 |
+
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
545 |
+
else:
|
546 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
547 |
+
video = []
|
548 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
549 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
550 |
+
video = torch.cat(video)
|
551 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
552 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
553 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
554 |
+
video = video.cpu().float().numpy()
|
555 |
+
return video
|
556 |
+
|
557 |
+
@property
|
558 |
+
def guidance_scale(self):
|
559 |
+
return self._guidance_scale
|
560 |
+
|
561 |
+
@property
|
562 |
+
def guidance_rescale(self):
|
563 |
+
return self._guidance_rescale
|
564 |
+
|
565 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
566 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
567 |
+
# corresponds to doing no classifier free guidance.
|
568 |
+
@property
|
569 |
+
def do_classifier_free_guidance(self):
|
570 |
+
return self._guidance_scale > 1
|
571 |
+
|
572 |
+
@property
|
573 |
+
def num_timesteps(self):
|
574 |
+
return self._num_timesteps
|
575 |
+
|
576 |
+
@property
|
577 |
+
def interrupt(self):
|
578 |
+
return self._interrupt
|
579 |
+
|
580 |
+
def enable_autocast_float8_transformer(self):
|
581 |
+
self.enable_autocast_float8_transformer_flag = True
|
582 |
+
|
583 |
+
@torch.no_grad()
|
584 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
585 |
+
def __call__(
|
586 |
+
self,
|
587 |
+
prompt: Union[str, List[str]] = None,
|
588 |
+
video_length: Optional[int] = None,
|
589 |
+
height: Optional[int] = None,
|
590 |
+
width: Optional[int] = None,
|
591 |
+
num_inference_steps: Optional[int] = 50,
|
592 |
+
guidance_scale: Optional[float] = 5.0,
|
593 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
594 |
+
num_images_per_prompt: Optional[int] = 1,
|
595 |
+
eta: Optional[float] = 0.0,
|
596 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
597 |
+
latents: Optional[torch.Tensor] = None,
|
598 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
599 |
+
prompt_embeds_2: Optional[torch.Tensor] = None,
|
600 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
601 |
+
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
602 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
603 |
+
prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
604 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
605 |
+
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
606 |
+
output_type: Optional[str] = "latent",
|
607 |
+
return_dict: bool = True,
|
608 |
+
callback_on_step_end: Optional[
|
609 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
610 |
+
] = None,
|
611 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
612 |
+
guidance_rescale: float = 0.0,
|
613 |
+
original_size: Optional[Tuple[int, int]] = (1024, 1024),
|
614 |
+
target_size: Optional[Tuple[int, int]] = None,
|
615 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
616 |
+
comfyui_progressbar: bool = False,
|
617 |
+
):
|
618 |
+
r"""
|
619 |
+
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
620 |
+
|
621 |
+
Examples:
|
622 |
+
prompt (`str` or `List[str]`, *optional*):
|
623 |
+
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
624 |
+
video_length (`int`, *optional*):
|
625 |
+
Length of the generated video (in frames).
|
626 |
+
height (`int`, *optional*):
|
627 |
+
Height of the generated image in pixels.
|
628 |
+
width (`int`, *optional*):
|
629 |
+
Width of the generated image in pixels.
|
630 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
631 |
+
Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
|
632 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
633 |
+
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
634 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
635 |
+
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
636 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
637 |
+
Number of images to generate for each prompt.
|
638 |
+
eta (`float`, *optional*, defaults to 0.0):
|
639 |
+
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
640 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
641 |
+
A generator to ensure reproducibility in image generation.
|
642 |
+
latents (`torch.Tensor`, *optional*):
|
643 |
+
Predefined latent tensors to condition generation.
|
644 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
645 |
+
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
646 |
+
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
647 |
+
Secondary text embeddings to supplement or replace the initial prompt embeddings.
|
648 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
649 |
+
Embeddings for negative prompts. Overrides string inputs if defined.
|
650 |
+
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
651 |
+
Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
|
652 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
653 |
+
Attention mask for the primary prompt embeddings.
|
654 |
+
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
655 |
+
Attention mask for the secondary prompt embeddings.
|
656 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
657 |
+
Attention mask for negative prompt embeddings.
|
658 |
+
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
659 |
+
Attention mask for secondary negative prompt embeddings.
|
660 |
+
output_type (`str`, *optional*, defaults to "latent"):
|
661 |
+
Format of the generated output, either as a PIL image or as a NumPy array.
|
662 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
663 |
+
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
664 |
+
callback_on_step_end (`Callable`, *optional*):
|
665 |
+
Functions called at the end of each denoising step.
|
666 |
+
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
667 |
+
Tensor names to be included in callback function calls.
|
668 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
669 |
+
Adjusts noise levels based on guidance scale.
|
670 |
+
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
671 |
+
Original dimensions of the output.
|
672 |
+
target_size (`Tuple[int, int]`, *optional*):
|
673 |
+
Desired output dimensions for calculations.
|
674 |
+
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
675 |
+
Coordinates for cropping.
|
676 |
+
|
677 |
+
Returns:
|
678 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
679 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
680 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
681 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
682 |
+
"not-safe-for-work" (nsfw) content.
|
683 |
+
"""
|
684 |
+
|
685 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
686 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
687 |
+
|
688 |
+
# 0. default height and width
|
689 |
+
height = int((height // 16) * 16)
|
690 |
+
width = int((width // 16) * 16)
|
691 |
+
|
692 |
+
# 1. Check inputs. Raise error if not correct
|
693 |
+
self.check_inputs(
|
694 |
+
prompt,
|
695 |
+
height,
|
696 |
+
width,
|
697 |
+
negative_prompt,
|
698 |
+
prompt_embeds,
|
699 |
+
negative_prompt_embeds,
|
700 |
+
prompt_attention_mask,
|
701 |
+
negative_prompt_attention_mask,
|
702 |
+
prompt_embeds_2,
|
703 |
+
negative_prompt_embeds_2,
|
704 |
+
prompt_attention_mask_2,
|
705 |
+
negative_prompt_attention_mask_2,
|
706 |
+
callback_on_step_end_tensor_inputs,
|
707 |
+
)
|
708 |
+
self._guidance_scale = guidance_scale
|
709 |
+
self._guidance_rescale = guidance_rescale
|
710 |
+
self._interrupt = False
|
711 |
+
|
712 |
+
# 2. Define call parameters
|
713 |
+
if prompt is not None and isinstance(prompt, str):
|
714 |
+
batch_size = 1
|
715 |
+
elif prompt is not None and isinstance(prompt, list):
|
716 |
+
batch_size = len(prompt)
|
717 |
+
else:
|
718 |
+
batch_size = prompt_embeds.shape[0]
|
719 |
+
|
720 |
+
device = self._execution_device
|
721 |
+
|
722 |
+
# 3. Encode input prompt
|
723 |
+
(
|
724 |
+
prompt_embeds,
|
725 |
+
negative_prompt_embeds,
|
726 |
+
prompt_attention_mask,
|
727 |
+
negative_prompt_attention_mask,
|
728 |
+
) = self.encode_prompt(
|
729 |
+
prompt=prompt,
|
730 |
+
device=device,
|
731 |
+
dtype=self.transformer.dtype,
|
732 |
+
num_images_per_prompt=num_images_per_prompt,
|
733 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
734 |
+
negative_prompt=negative_prompt,
|
735 |
+
prompt_embeds=prompt_embeds,
|
736 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
737 |
+
prompt_attention_mask=prompt_attention_mask,
|
738 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
739 |
+
text_encoder_index=0,
|
740 |
+
)
|
741 |
+
(
|
742 |
+
prompt_embeds_2,
|
743 |
+
negative_prompt_embeds_2,
|
744 |
+
prompt_attention_mask_2,
|
745 |
+
negative_prompt_attention_mask_2,
|
746 |
+
) = self.encode_prompt(
|
747 |
+
prompt=prompt,
|
748 |
+
device=device,
|
749 |
+
dtype=self.transformer.dtype,
|
750 |
+
num_images_per_prompt=num_images_per_prompt,
|
751 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
752 |
+
negative_prompt=negative_prompt,
|
753 |
+
prompt_embeds=prompt_embeds_2,
|
754 |
+
negative_prompt_embeds=negative_prompt_embeds_2,
|
755 |
+
prompt_attention_mask=prompt_attention_mask_2,
|
756 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
|
757 |
+
text_encoder_index=1,
|
758 |
+
)
|
759 |
+
torch.cuda.empty_cache()
|
760 |
+
|
761 |
+
# 4. Prepare timesteps
|
762 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
763 |
+
timesteps = self.scheduler.timesteps
|
764 |
+
if comfyui_progressbar:
|
765 |
+
from comfy.utils import ProgressBar
|
766 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
767 |
+
|
768 |
+
# 5. Prepare latent variables
|
769 |
+
num_channels_latents = self.transformer.config.in_channels
|
770 |
+
latents = self.prepare_latents(
|
771 |
+
batch_size * num_images_per_prompt,
|
772 |
+
num_channels_latents,
|
773 |
+
video_length,
|
774 |
+
height,
|
775 |
+
width,
|
776 |
+
prompt_embeds.dtype,
|
777 |
+
device,
|
778 |
+
generator,
|
779 |
+
latents,
|
780 |
+
)
|
781 |
+
if comfyui_progressbar:
|
782 |
+
pbar.update(1)
|
783 |
+
|
784 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
785 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
786 |
+
|
787 |
+
# 7 create image_rotary_emb, style embedding & time ids
|
788 |
+
grid_height = height // 8 // self.transformer.config.patch_size
|
789 |
+
grid_width = width // 8 // self.transformer.config.patch_size
|
790 |
+
if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
|
791 |
+
base_size_width = 720 // 8 // self.transformer.config.patch_size
|
792 |
+
base_size_height = 480 // 8 // self.transformer.config.patch_size
|
793 |
+
|
794 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
795 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
796 |
+
)
|
797 |
+
image_rotary_emb = get_3d_rotary_pos_embed(
|
798 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
|
799 |
+
temporal_size=latents.size(2), use_real=True,
|
800 |
+
)
|
801 |
+
else:
|
802 |
+
base_size = 512 // 8 // self.transformer.config.patch_size
|
803 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
804 |
+
(grid_height, grid_width), base_size, base_size
|
805 |
+
)
|
806 |
+
image_rotary_emb = get_2d_rotary_pos_embed(
|
807 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
|
808 |
+
)
|
809 |
+
|
810 |
+
# Get other hunyuan params
|
811 |
+
style = torch.tensor([0], device=device)
|
812 |
+
|
813 |
+
target_size = target_size or (height, width)
|
814 |
+
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
815 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
816 |
+
|
817 |
+
if self.do_classifier_free_guidance:
|
818 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
819 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
820 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
821 |
+
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
822 |
+
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
823 |
+
style = torch.cat([style] * 2, dim=0)
|
824 |
+
|
825 |
+
# To latents.device
|
826 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
827 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
828 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
829 |
+
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
830 |
+
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
|
831 |
+
batch_size * num_images_per_prompt, 1
|
832 |
+
)
|
833 |
+
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
834 |
+
|
835 |
+
torch.cuda.empty_cache()
|
836 |
+
if self.enable_autocast_float8_transformer_flag:
|
837 |
+
origin_weight_dtype = self.transformer.dtype
|
838 |
+
self.transformer = self.transformer.to(torch.float8_e4m3fn)
|
839 |
+
# 8. Denoising loop
|
840 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
841 |
+
self._num_timesteps = len(timesteps)
|
842 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
843 |
+
for i, t in enumerate(timesteps):
|
844 |
+
if self.interrupt:
|
845 |
+
continue
|
846 |
+
|
847 |
+
# expand the latents if we are doing classifier free guidance
|
848 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
849 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
850 |
+
|
851 |
+
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
852 |
+
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
853 |
+
dtype=latent_model_input.dtype
|
854 |
+
)
|
855 |
+
|
856 |
+
# predict the noise residual
|
857 |
+
noise_pred = self.transformer(
|
858 |
+
latent_model_input,
|
859 |
+
t_expand,
|
860 |
+
encoder_hidden_states=prompt_embeds,
|
861 |
+
text_embedding_mask=prompt_attention_mask,
|
862 |
+
encoder_hidden_states_t5=prompt_embeds_2,
|
863 |
+
text_embedding_mask_t5=prompt_attention_mask_2,
|
864 |
+
image_meta_size=add_time_ids,
|
865 |
+
style=style,
|
866 |
+
image_rotary_emb=image_rotary_emb,
|
867 |
+
return_dict=False,
|
868 |
+
)[0]
|
869 |
+
|
870 |
+
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
871 |
+
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
872 |
+
|
873 |
+
# perform guidance
|
874 |
+
if self.do_classifier_free_guidance:
|
875 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
876 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
877 |
+
|
878 |
+
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
879 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
880 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
881 |
+
|
882 |
+
# compute the previous noisy sample x_t -> x_t-1
|
883 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
884 |
+
|
885 |
+
if callback_on_step_end is not None:
|
886 |
+
callback_kwargs = {}
|
887 |
+
for k in callback_on_step_end_tensor_inputs:
|
888 |
+
callback_kwargs[k] = locals()[k]
|
889 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
890 |
+
|
891 |
+
latents = callback_outputs.pop("latents", latents)
|
892 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
893 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
894 |
+
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
|
895 |
+
negative_prompt_embeds_2 = callback_outputs.pop(
|
896 |
+
"negative_prompt_embeds_2", negative_prompt_embeds_2
|
897 |
+
)
|
898 |
+
|
899 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
900 |
+
progress_bar.update()
|
901 |
+
|
902 |
+
if XLA_AVAILABLE:
|
903 |
+
xm.mark_step()
|
904 |
+
|
905 |
+
if comfyui_progressbar:
|
906 |
+
pbar.update(1)
|
907 |
+
|
908 |
+
if self.enable_autocast_float8_transformer_flag:
|
909 |
+
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
910 |
+
|
911 |
+
torch.cuda.empty_cache()
|
912 |
+
# Post-processing
|
913 |
+
video = self.decode_latents(latents)
|
914 |
+
|
915 |
+
# Convert to tensor
|
916 |
+
if output_type == "latent":
|
917 |
+
video = torch.from_numpy(video)
|
918 |
+
|
919 |
+
# Offload all models
|
920 |
+
self.maybe_free_model_hooks()
|
921 |
+
|
922 |
+
if not return_dict:
|
923 |
+
return video
|
924 |
+
|
925 |
+
return EasyAnimatePipelineOutput(videos=video)
|
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py
ADDED
@@ -0,0 +1,996 @@
1 |
+
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
import re
|
17 |
+
import urllib.parse as ul
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
25 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
26 |
+
from diffusers.image_processor import VaeImageProcessor
|
27 |
+
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
|
28 |
+
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
|
29 |
+
get_3d_rotary_pos_embed)
|
30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
31 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
32 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
33 |
+
StableDiffusionSafetyChecker
|
34 |
+
from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
|
35 |
+
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
36 |
+
is_bs4_available, is_ftfy_available,
|
37 |
+
is_torch_xla_available, logging,
|
38 |
+
replace_example_docstring)
|
39 |
+
from diffusers.utils.torch_utils import randn_tensor
|
40 |
+
from einops import rearrange
|
41 |
+
from PIL import Image
|
42 |
+
from tqdm import tqdm
|
43 |
+
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
44 |
+
CLIPVisionModelWithProjection,
|
45 |
+
T5EncoderModel, T5Tokenizer)
|
46 |
+
|
47 |
+
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
48 |
+
from .pipeline_easyanimate import EasyAnimatePipelineOutput
|
49 |
+
|
50 |
+
if is_torch_xla_available():
|
51 |
+
import torch_xla.core.xla_model as xm
|
52 |
+
|
53 |
+
XLA_AVAILABLE = True
|
54 |
+
else:
|
55 |
+
XLA_AVAILABLE = False
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
59 |
+
|
60 |
+
EXAMPLE_DOC_STRING = """
|
61 |
+
Examples:
|
62 |
+
```py
|
63 |
+
>>> pass
|
64 |
+
```
|
65 |
+
"""
|
66 |
+
|
67 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
68 |
+
tw = tgt_width
|
69 |
+
th = tgt_height
|
70 |
+
h, w = src
|
71 |
+
r = h / w
|
72 |
+
if r > (th / tw):
|
73 |
+
resize_height = th
|
74 |
+
resize_width = int(round(th / h * w))
|
75 |
+
else:
|
76 |
+
resize_width = tw
|
77 |
+
resize_height = int(round(tw / w * h))
|
78 |
+
|
79 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
80 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
81 |
+
|
82 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
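A small sanity check of the helper above, assuming it is in scope; the sample resolution, patch size, and expected tuple were worked out by hand from the formula rather than taken from the commit.

# A 384x672 sample with patch_size=2 gives a (24, 42) latent grid; the RoPE base grid for 480x720 is (30, 45).
grid = (384 // 8 // 2, 672 // 8 // 2)            # (grid_height, grid_width) = (24, 42)
base_w, base_h = 720 // 8 // 2, 480 // 8 // 2    # (45, 30)
# The grid is scaled to span the base width and centered vertically inside the base height.
assert get_resize_crop_region_for_grid(grid, base_w, base_h) == ((2, 0), (28, 45))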
|
83 |
+
|
84 |
+
|
85 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
86 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
87 |
+
"""
|
88 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
89 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
90 |
+
"""
|
91 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
92 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
93 |
+
# rescale the results from guidance (fixes overexposure)
|
94 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
95 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
96 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
97 |
+
return noise_cfg
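A quick numerical illustration of the rescaling above, assuming `rescale_noise_cfg` is in scope; the random tensors and shapes are placeholders. With `guidance_rescale=1.0` the returned prediction is forced to the per-sample standard deviation of `noise_pred_text`.

import torch

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 16, 16)
noise_cfg = 7.5 * torch.randn(2, 4, 16, 16)   # stand-in for an over-amplified CFG prediction
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
# Per-sample std of the rescaled prediction now matches the text-conditioned prediction.
assert torch.allclose(rescaled.std(dim=(1, 2, 3)), noise_pred_text.std(dim=(1, 2, 3)), atol=1e-5)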
|
98 |
+
|
99 |
+
|
100 |
+
class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
101 |
+
r"""
|
102 |
+
Pipeline for text-to-video generation using EasyAnimate.
|
103 |
+
|
104 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
105 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
106 |
+
|
107 |
+
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP (fine-tuned by the
|
108 |
+
HunyuanDiT team)
|
109 |
+
|
110 |
+
Args:
|
111 |
+
vae ([`AutoencoderKLMagvit`]):
|
112 |
+
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
113 |
+
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
|
114 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
115 |
+
EasyAnimate uses a fine-tuned bilingual CLIP.
|
116 |
+
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
|
117 |
+
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
|
118 |
+
transformer ([`EasyAnimateTransformer3DModel`]):
|
119 |
+
The EasyAnimate 3D transformer that denoises the encoded video latents.
|
120 |
+
text_encoder_2 (`T5EncoderModel`):
|
121 |
+
The mT5 embedder.
|
122 |
+
tokenizer_2 (`T5Tokenizer`):
|
123 |
+
The tokenizer for the mT5 embedder.
|
124 |
+
scheduler ([`DDIMScheduler`]):
|
125 |
+
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
126 |
+
"""
|
127 |
+
|
128 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
129 |
+
_optional_components = [
|
130 |
+
"safety_checker",
|
131 |
+
"feature_extractor",
|
132 |
+
"text_encoder_2",
|
133 |
+
"tokenizer_2",
|
134 |
+
"text_encoder",
|
135 |
+
"tokenizer",
|
136 |
+
]
|
137 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
138 |
+
_callback_tensor_inputs = [
|
139 |
+
"latents",
|
140 |
+
"prompt_embeds",
|
141 |
+
"negative_prompt_embeds",
|
142 |
+
"prompt_embeds_2",
|
143 |
+
"negative_prompt_embeds_2",
|
144 |
+
]
|
145 |
+
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
vae: AutoencoderKLMagvit,
|
149 |
+
text_encoder: BertModel,
|
150 |
+
tokenizer: BertTokenizer,
|
151 |
+
text_encoder_2: T5EncoderModel,
|
152 |
+
tokenizer_2: T5Tokenizer,
|
153 |
+
transformer: EasyAnimateTransformer3DModel,
|
154 |
+
scheduler: DDIMScheduler,
|
155 |
+
safety_checker: StableDiffusionSafetyChecker,
|
156 |
+
feature_extractor: CLIPImageProcessor,
|
157 |
+
requires_safety_checker: bool = True
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
|
161 |
+
self.register_modules(
|
162 |
+
vae=vae,
|
163 |
+
text_encoder=text_encoder,
|
164 |
+
tokenizer=tokenizer,
|
165 |
+
tokenizer_2=tokenizer_2,
|
166 |
+
transformer=transformer,
|
167 |
+
scheduler=scheduler,
|
168 |
+
safety_checker=safety_checker,
|
169 |
+
feature_extractor=feature_extractor,
|
170 |
+
text_encoder_2=text_encoder_2
|
171 |
+
)
|
172 |
+
|
173 |
+
if safety_checker is None and requires_safety_checker:
|
174 |
+
logger.warning(
|
175 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
176 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
177 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
178 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
179 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
180 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
181 |
+
)
|
182 |
+
|
183 |
+
if safety_checker is not None and feature_extractor is None:
|
184 |
+
raise ValueError(
|
185 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
186 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
187 |
+
)
|
188 |
+
|
189 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
190 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
191 |
+
self.mask_processor = VaeImageProcessor(
|
192 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
193 |
+
)
|
194 |
+
self.enable_autocast_float8_transformer_flag = False
|
195 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
196 |
+
|
197 |
+
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
198 |
+
super().enable_sequential_cpu_offload(*args, **kwargs)
|
199 |
+
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
|
200 |
+
import accelerate
|
201 |
+
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
|
202 |
+
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
|
203 |
+
|
204 |
+
def encode_prompt(
|
205 |
+
self,
|
206 |
+
prompt: str,
|
207 |
+
device: torch.device,
|
208 |
+
dtype: torch.dtype,
|
209 |
+
num_images_per_prompt: int = 1,
|
210 |
+
do_classifier_free_guidance: bool = True,
|
211 |
+
negative_prompt: Optional[str] = None,
|
212 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
213 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
214 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
215 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
216 |
+
max_sequence_length: Optional[int] = None,
|
217 |
+
text_encoder_index: int = 0,
|
218 |
+
actual_max_sequence_length: int = 256
|
219 |
+
):
|
220 |
+
r"""
|
221 |
+
Encodes the prompt into text encoder hidden states.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
prompt (`str` or `List[str]`, *optional*):
|
225 |
+
prompt to be encoded
|
226 |
+
device: (`torch.device`):
|
227 |
+
torch device
|
228 |
+
dtype (`torch.dtype`):
|
229 |
+
torch dtype
|
230 |
+
num_images_per_prompt (`int`):
|
231 |
+
number of images that should be generated per prompt
|
232 |
+
do_classifier_free_guidance (`bool`):
|
233 |
+
whether to use classifier free guidance or not
|
234 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
235 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
236 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
237 |
+
less than `1`).
|
238 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
239 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
240 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
241 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
242 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
243 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
244 |
+
argument.
|
245 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
246 |
+
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
247 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
248 |
+
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
249 |
+
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
250 |
+
text_encoder_index (`int`, *optional*):
|
251 |
+
Index of the text encoder to use. `0` for clip and `1` for T5.
|
252 |
+
"""
|
253 |
+
tokenizers = [self.tokenizer, self.tokenizer_2]
|
254 |
+
text_encoders = [self.text_encoder, self.text_encoder_2]
|
255 |
+
|
256 |
+
tokenizer = tokenizers[text_encoder_index]
|
257 |
+
text_encoder = text_encoders[text_encoder_index]
|
258 |
+
|
259 |
+
if max_sequence_length is None:
|
260 |
+
if text_encoder_index == 0:
|
261 |
+
max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
|
262 |
+
if text_encoder_index == 1:
|
263 |
+
max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
|
264 |
+
else:
|
265 |
+
max_length = max_sequence_length
|
266 |
+
|
267 |
+
if prompt is not None and isinstance(prompt, str):
|
268 |
+
batch_size = 1
|
269 |
+
elif prompt is not None and isinstance(prompt, list):
|
270 |
+
batch_size = len(prompt)
|
271 |
+
else:
|
272 |
+
batch_size = prompt_embeds.shape[0]
|
273 |
+
|
274 |
+
if prompt_embeds is None:
|
275 |
+
text_inputs = tokenizer(
|
276 |
+
prompt,
|
277 |
+
padding="max_length",
|
278 |
+
max_length=max_length,
|
279 |
+
truncation=True,
|
280 |
+
return_attention_mask=True,
|
281 |
+
return_tensors="pt",
|
282 |
+
)
|
283 |
+
text_input_ids = text_inputs.input_ids
|
284 |
+
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
285 |
+
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
286 |
+
text_inputs = tokenizer(
|
287 |
+
reprompt,
|
288 |
+
padding="max_length",
|
289 |
+
max_length=max_length,
|
290 |
+
truncation=True,
|
291 |
+
return_attention_mask=True,
|
292 |
+
return_tensors="pt",
|
293 |
+
)
|
294 |
+
text_input_ids = text_inputs.input_ids
|
295 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
296 |
+
|
297 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
298 |
+
text_input_ids, untruncated_ids
|
299 |
+
):
|
300 |
+
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
301 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
302 |
+
logger.warning(
|
303 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
304 |
+
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
305 |
+
)
|
306 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
307 |
+
if self.transformer.config.enable_text_attention_mask:
|
308 |
+
prompt_embeds = text_encoder(
|
309 |
+
text_input_ids.to(device),
|
310 |
+
attention_mask=prompt_attention_mask,
|
311 |
+
)
|
312 |
+
else:
|
313 |
+
prompt_embeds = text_encoder(
|
314 |
+
text_input_ids.to(device)
|
315 |
+
)
|
316 |
+
prompt_embeds = prompt_embeds[0]
|
317 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
318 |
+
|
319 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
320 |
+
|
321 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
322 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
323 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
324 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
325 |
+
|
326 |
+
# get unconditional embeddings for classifier free guidance
|
327 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
328 |
+
uncond_tokens: List[str]
|
329 |
+
if negative_prompt is None:
|
330 |
+
uncond_tokens = [""] * batch_size
|
331 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
332 |
+
raise TypeError(
|
333 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
334 |
+
f" {type(prompt)}."
|
335 |
+
)
|
336 |
+
elif isinstance(negative_prompt, str):
|
337 |
+
uncond_tokens = [negative_prompt]
|
338 |
+
elif batch_size != len(negative_prompt):
|
339 |
+
raise ValueError(
|
340 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
341 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
342 |
+
" the batch size of `prompt`."
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
uncond_tokens = negative_prompt
|
346 |
+
|
347 |
+
max_length = prompt_embeds.shape[1]
|
348 |
+
uncond_input = tokenizer(
|
349 |
+
uncond_tokens,
|
350 |
+
padding="max_length",
|
351 |
+
max_length=max_length,
|
352 |
+
truncation=True,
|
353 |
+
return_tensors="pt",
|
354 |
+
)
|
355 |
+
uncond_input_ids = uncond_input.input_ids
|
356 |
+
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
357 |
+
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
358 |
+
uncond_input = tokenizer(
|
359 |
+
reuncond_tokens,
|
360 |
+
padding="max_length",
|
361 |
+
max_length=max_length,
|
362 |
+
truncation=True,
|
363 |
+
return_attention_mask=True,
|
364 |
+
return_tensors="pt",
|
365 |
+
)
|
366 |
+
uncond_input_ids = uncond_input.input_ids
|
367 |
+
|
368 |
+
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
369 |
+
if self.transformer.config.enable_text_attention_mask:
|
370 |
+
negative_prompt_embeds = text_encoder(
|
371 |
+
uncond_input.input_ids.to(device),
|
372 |
+
attention_mask=negative_prompt_attention_mask,
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
negative_prompt_embeds = text_encoder(
|
376 |
+
uncond_input.input_ids.to(device)
|
377 |
+
)
|
378 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
379 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
380 |
+
|
381 |
+
if do_classifier_free_guidance:
|
382 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
383 |
+
seq_len = negative_prompt_embeds.shape[1]
|
384 |
+
|
385 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
386 |
+
|
387 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
388 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
389 |
+
|
390 |
+
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
391 |
+
|
392 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
393 |
+
def run_safety_checker(self, image, device, dtype):
|
394 |
+
if self.safety_checker is None:
|
395 |
+
has_nsfw_concept = None
|
396 |
+
else:
|
397 |
+
if torch.is_tensor(image):
|
398 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
399 |
+
else:
|
400 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
401 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
402 |
+
image, has_nsfw_concept = self.safety_checker(
|
403 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
404 |
+
)
|
405 |
+
return image, has_nsfw_concept
|
406 |
+
|
407 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
408 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
409 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
410 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
411 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
412 |
+
# and should be between [0, 1]
|
413 |
+
|
414 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
415 |
+
extra_step_kwargs = {}
|
416 |
+
if accepts_eta:
|
417 |
+
extra_step_kwargs["eta"] = eta
|
418 |
+
|
419 |
+
# check if the scheduler accepts generator
|
420 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
421 |
+
if accepts_generator:
|
422 |
+
extra_step_kwargs["generator"] = generator
|
423 |
+
return extra_step_kwargs
|
424 |
+
|
425 |
+
def check_inputs(
|
426 |
+
self,
|
427 |
+
prompt,
|
428 |
+
height,
|
429 |
+
width,
|
430 |
+
negative_prompt=None,
|
431 |
+
prompt_embeds=None,
|
432 |
+
negative_prompt_embeds=None,
|
433 |
+
prompt_attention_mask=None,
|
434 |
+
negative_prompt_attention_mask=None,
|
435 |
+
prompt_embeds_2=None,
|
436 |
+
negative_prompt_embeds_2=None,
|
437 |
+
prompt_attention_mask_2=None,
|
438 |
+
negative_prompt_attention_mask_2=None,
|
439 |
+
callback_on_step_end_tensor_inputs=None,
|
440 |
+
):
|
441 |
+
if height % 8 != 0 or width % 8 != 0:
|
442 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
443 |
+
|
444 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
445 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
446 |
+
):
|
447 |
+
raise ValueError(
|
448 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
449 |
+
)
|
450 |
+
|
451 |
+
if prompt is not None and prompt_embeds is not None:
|
452 |
+
raise ValueError(
|
453 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
454 |
+
" only forward one of the two."
|
455 |
+
)
|
456 |
+
elif prompt is None and prompt_embeds is None:
|
457 |
+
raise ValueError(
|
458 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
459 |
+
)
|
460 |
+
elif prompt is None and prompt_embeds_2 is None:
|
461 |
+
raise ValueError(
|
462 |
+
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
463 |
+
)
|
464 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
465 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
466 |
+
|
467 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
468 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
469 |
+
|
470 |
+
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
|
471 |
+
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
|
472 |
+
|
473 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
474 |
+
raise ValueError(
|
475 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
476 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
477 |
+
)
|
478 |
+
|
479 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
480 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
481 |
+
|
482 |
+
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
|
483 |
+
raise ValueError(
|
484 |
+
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
|
485 |
+
)
|
486 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
487 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
488 |
+
raise ValueError(
|
489 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
490 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
491 |
+
f" {negative_prompt_embeds.shape}."
|
492 |
+
)
|
493 |
+
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
|
494 |
+
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
|
495 |
+
raise ValueError(
|
496 |
+
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
|
497 |
+
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
|
498 |
+
f" {negative_prompt_embeds_2.shape}."
|
499 |
+
)
|
500 |
+
|
501 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
502 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
503 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
504 |
+
if self.vae.cache_mag_vae:
|
505 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
506 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
507 |
+
shape = (
    batch_size,
    num_channels_latents,
    int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1,
    height // self.vae_scale_factor,
    width // self.vae_scale_factor,
)
|
508 |
+
else:
|
509 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
510 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
511 |
+
shape = (
    batch_size,
    num_channels_latents,
    int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1,
    height // self.vae_scale_factor,
    width // self.vae_scale_factor,
)
|
512 |
+
else:
|
513 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
514 |
+
|
515 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
516 |
+
raise ValueError(
|
517 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
518 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
519 |
+
)
|
520 |
+
|
521 |
+
if latents is None:
|
522 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
523 |
+
else:
|
524 |
+
latents = latents.to(device)
|
525 |
+
|
526 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
527 |
+
latents = latents * self.scheduler.init_noise_sigma
|
528 |
+
return latents
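To make the magvit branch above concrete, a shape calculation under assumed VAE settings (`mini_batch_encoder=4`, `mini_batch_decoder=1`, 8x spatial compression, 16 latent channels); the real values come from the loaded `AutoencoderKLMagvit` config, not from this diff.

# Assumed settings, for illustration only.
batch_size, num_channels_latents = 1, 16
video_length, height, width = 49, 512, 512
mini_batch_encoder, mini_batch_decoder, vae_scale_factor = 4, 1, 8

temporal = (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1   # -> 13
shape = (batch_size, num_channels_latents, temporal,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 16, 13, 64, 64) under these assumptions (cache_mag_vae branch)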
|
529 |
+
|
530 |
+
def prepare_control_latents(
|
531 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
532 |
+
):
|
533 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
534 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
535 |
+
# and half precision
|
536 |
+
|
537 |
+
if mask is not None:
|
538 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
539 |
+
bs = 1
|
540 |
+
new_mask = []
|
541 |
+
for i in range(0, mask.shape[0], bs):
|
542 |
+
mask_bs = mask[i : i + bs]
|
543 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
544 |
+
mask_bs = mask_bs.mode()
|
545 |
+
new_mask.append(mask_bs)
|
546 |
+
mask = torch.cat(new_mask, dim = 0)
|
547 |
+
mask = mask * self.vae.config.scaling_factor
|
548 |
+
|
549 |
+
if masked_image is not None:
|
550 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
551 |
+
bs = 1
|
552 |
+
new_mask_pixel_values = []
|
553 |
+
for i in range(0, masked_image.shape[0], bs):
|
554 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
555 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
556 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
557 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
558 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
559 |
+
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
560 |
+
else:
|
561 |
+
masked_image_latents = None
|
562 |
+
|
563 |
+
return mask, masked_image_latents
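A sketch of shaping a control video before handing it to `__call__` below; the uint8 frame source and the [0, 1] scaling are assumptions about the `VaeImageProcessor.preprocess` input convention, not requirements stated in this commit.

import numpy as np
import torch

# Hypothetical frame source: F frames of HxWx3 uint8 pixels from any video reader.
frames = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(49)]
control_video = torch.from_numpy(np.stack(frames))           # (F, H, W, C)
control_video = control_video.permute(3, 0, 1, 2).float() / 255.0
control_video = control_video.unsqueeze(0)                    # (B, C, F, H, W), as consumed by __call__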
|
564 |
+
|
565 |
+
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
566 |
+
if video.size()[2] <= mini_batch_encoder:
|
567 |
+
return video
|
568 |
+
prefix_index_before = mini_batch_encoder // 2
|
569 |
+
prefix_index_after = mini_batch_encoder - prefix_index_before
|
570 |
+
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
571 |
+
|
572 |
+
# Encode middle videos
|
573 |
+
latents = self.vae.encode(pixel_values)[0]
|
574 |
+
latents = latents.mode()
|
575 |
+
# Decode middle videos
|
576 |
+
middle_video = self.vae.decode(latents)[0]
|
577 |
+
|
578 |
+
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
579 |
+
return video
|
580 |
+
|
581 |
+
def decode_latents(self, latents):
|
582 |
+
video_length = latents.shape[2]
|
583 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
584 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
585 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
586 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
587 |
+
video = self.vae.decode(latents)[0]
|
588 |
+
video = video.clamp(-1, 1)
|
589 |
+
if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
|
590 |
+
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
591 |
+
else:
|
592 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
593 |
+
video = []
|
594 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
595 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
596 |
+
video = torch.cat(video)
|
597 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
598 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
599 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
600 |
+
video = video.cpu().float().numpy()
|
601 |
+
return video
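`decode_latents` hands back a float array in [0, 1] shaped (B, C, F, H, W); a short sketch of turning that into uint8 frames for saving. The random stand-in array and the video-writer call in the comment are illustrative assumptions.

import numpy as np

video = np.random.rand(1, 3, 16, 256, 256).astype(np.float32)             # stand-in for decode_latents output
frames = (video[0].transpose(1, 2, 3, 0) * 255).round().astype(np.uint8)  # (F, H, W, C)
# frames can now be written with any video writer, e.g. imageio.mimwrite("sample.mp4", frames, fps=8)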
|
602 |
+
|
603 |
+
@property
|
604 |
+
def guidance_scale(self):
|
605 |
+
return self._guidance_scale
|
606 |
+
|
607 |
+
@property
|
608 |
+
def guidance_rescale(self):
|
609 |
+
return self._guidance_rescale
|
610 |
+
|
611 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
612 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
613 |
+
# corresponds to doing no classifier free guidance.
|
614 |
+
@property
|
615 |
+
def do_classifier_free_guidance(self):
|
616 |
+
return self._guidance_scale > 1
|
617 |
+
|
618 |
+
@property
|
619 |
+
def num_timesteps(self):
|
620 |
+
return self._num_timesteps
|
621 |
+
|
622 |
+
@property
|
623 |
+
def interrupt(self):
|
624 |
+
return self._interrupt
|
625 |
+
|
626 |
+
def enable_autocast_float8_transformer(self):
|
627 |
+
self.enable_autocast_float8_transformer_flag = True
|
628 |
+
|
629 |
+
@torch.no_grad()
|
630 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
631 |
+
def __call__(
|
632 |
+
self,
|
633 |
+
prompt: Union[str, List[str]] = None,
|
634 |
+
video_length: Optional[int] = None,
|
635 |
+
height: Optional[int] = None,
|
636 |
+
width: Optional[int] = None,
|
637 |
+
control_video: Union[torch.FloatTensor] = None,
|
638 |
+
num_inference_steps: Optional[int] = 50,
|
639 |
+
guidance_scale: Optional[float] = 5.0,
|
640 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
641 |
+
num_images_per_prompt: Optional[int] = 1,
|
642 |
+
eta: Optional[float] = 0.0,
|
643 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
644 |
+
latents: Optional[torch.Tensor] = None,
|
645 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
646 |
+
prompt_embeds_2: Optional[torch.Tensor] = None,
|
647 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
648 |
+
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
649 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
650 |
+
prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
651 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
652 |
+
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
653 |
+
output_type: Optional[str] = "latent",
|
654 |
+
return_dict: bool = True,
|
655 |
+
callback_on_step_end: Optional[
|
656 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
657 |
+
] = None,
|
658 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
659 |
+
guidance_rescale: float = 0.0,
|
660 |
+
original_size: Optional[Tuple[int, int]] = (1024, 1024),
|
661 |
+
target_size: Optional[Tuple[int, int]] = None,
|
662 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
663 |
+
comfyui_progressbar: bool = False,
|
664 |
+
):
|
665 |
+
r"""
|
666 |
+
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
667 |
+
|
668 |
+
Args:
|
669 |
+
prompt (`str` or `List[str]`, *optional*):
|
670 |
+
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
671 |
+
video_length (`int`, *optional*):
|
672 |
+
Length of the generated video (in frames).
|
673 |
+
height (`int`, *optional*):
|
674 |
+
Height of the generated image in pixels.
|
675 |
+
width (`int`, *optional*):
|
676 |
+
Width of the generated image in pixels.
|
677 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
678 |
+
Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
|
679 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
680 |
+
Classifier-free guidance scale; higher values follow the prompt more closely at the cost of diversity and, at very high settings, image quality.
|
681 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
682 |
+
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
683 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
684 |
+
Number of images to generate for each prompt.
|
685 |
+
eta (`float`, *optional*, defaults to 0.0):
|
686 |
+
Only applies to the DDIM scheduler; corresponds to the η parameter in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
|
687 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
688 |
+
A generator to ensure reproducibility in image generation.
|
689 |
+
latents (`torch.Tensor`, *optional*):
|
690 |
+
Predefined latent tensors to condition generation.
|
691 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
692 |
+
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
693 |
+
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
694 |
+
Secondary text embeddings to supplement or replace the initial prompt embeddings.
|
695 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
696 |
+
Embeddings for negative prompts. Overrides string inputs if defined.
|
697 |
+
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
698 |
+
Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
|
699 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
700 |
+
Attention mask for the primary prompt embeddings.
|
701 |
+
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
702 |
+
Attention mask for the secondary prompt embeddings.
|
703 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
704 |
+
Attention mask for negative prompt embeddings.
|
705 |
+
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
706 |
+
Attention mask for secondary negative prompt embeddings.
|
707 |
+
output_type (`str`, *optional*, defaults to "latent"):
|
708 |
+
Output format of the decoded video; `"latent"` returns a `torch.Tensor`, any other value returns a NumPy array.
|
709 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
710 |
+
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
711 |
+
callback_on_step_end (`Callable`, *optional*):
|
712 |
+
Functions called at the end of each denoising step.
|
713 |
+
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
714 |
+
Tensor names to be included in callback function calls.
|
715 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
716 |
+
Rescales the guided noise prediction following "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/abs/2305.08891) to reduce overexposure; 0.0 disables the rescaling.
|
717 |
+
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
718 |
+
Original dimensions of the output.
|
719 |
+
target_size (`Tuple[int, int]`, *optional*):
|
720 |
+
Desired output dimensions for calculations.
|
721 |
+
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
722 |
+
Coordinates for cropping.
|
723 |
+
|
724 |
+
Returns:
|
725 |
+
[`EasyAnimatePipelineOutput`] or `tuple`:
If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] containing the generated video is
returned; otherwise the decoded video is returned directly (a `torch.Tensor` or NumPy array
depending on `output_type`).
|
730 |
+
"""
|
731 |
+
|
732 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
733 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
734 |
+
|
735 |
+
# 0. default height and width
|
736 |
+
height = int((height // 16) * 16)
|
737 |
+
width = int((width // 16) * 16)
|
738 |
+
|
739 |
+
# 1. Check inputs. Raise error if not correct
|
740 |
+
self.check_inputs(
|
741 |
+
prompt,
|
742 |
+
height,
|
743 |
+
width,
|
744 |
+
negative_prompt,
|
745 |
+
prompt_embeds,
|
746 |
+
negative_prompt_embeds,
|
747 |
+
prompt_attention_mask,
|
748 |
+
negative_prompt_attention_mask,
|
749 |
+
prompt_embeds_2,
|
750 |
+
negative_prompt_embeds_2,
|
751 |
+
prompt_attention_mask_2,
|
752 |
+
negative_prompt_attention_mask_2,
|
753 |
+
callback_on_step_end_tensor_inputs,
|
754 |
+
)
|
755 |
+
self._guidance_scale = guidance_scale
|
756 |
+
self._guidance_rescale = guidance_rescale
|
757 |
+
self._interrupt = False
|
758 |
+
|
759 |
+
# 2. Define call parameters
|
760 |
+
if prompt is not None and isinstance(prompt, str):
|
761 |
+
batch_size = 1
|
762 |
+
elif prompt is not None and isinstance(prompt, list):
|
763 |
+
batch_size = len(prompt)
|
764 |
+
else:
|
765 |
+
batch_size = prompt_embeds.shape[0]
|
766 |
+
|
767 |
+
device = self._execution_device
|
768 |
+
|
769 |
+
# 3. Encode input prompt
|
770 |
+
(
|
771 |
+
prompt_embeds,
|
772 |
+
negative_prompt_embeds,
|
773 |
+
prompt_attention_mask,
|
774 |
+
negative_prompt_attention_mask,
|
775 |
+
) = self.encode_prompt(
|
776 |
+
prompt=prompt,
|
777 |
+
device=device,
|
778 |
+
dtype=self.transformer.dtype,
|
779 |
+
num_images_per_prompt=num_images_per_prompt,
|
780 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
781 |
+
negative_prompt=negative_prompt,
|
782 |
+
prompt_embeds=prompt_embeds,
|
783 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
784 |
+
prompt_attention_mask=prompt_attention_mask,
|
785 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
786 |
+
text_encoder_index=0,
|
787 |
+
)
|
788 |
+
(
|
789 |
+
prompt_embeds_2,
|
790 |
+
negative_prompt_embeds_2,
|
791 |
+
prompt_attention_mask_2,
|
792 |
+
negative_prompt_attention_mask_2,
|
793 |
+
) = self.encode_prompt(
|
794 |
+
prompt=prompt,
|
795 |
+
device=device,
|
796 |
+
dtype=self.transformer.dtype,
|
797 |
+
num_images_per_prompt=num_images_per_prompt,
|
798 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
799 |
+
negative_prompt=negative_prompt,
|
800 |
+
prompt_embeds=prompt_embeds_2,
|
801 |
+
negative_prompt_embeds=negative_prompt_embeds_2,
|
802 |
+
            prompt_attention_mask=prompt_attention_mask_2,
            negative_prompt_attention_mask=negative_prompt_attention_mask_2,
            text_encoder_index=1,
        )
        torch.cuda.empty_cache()

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 2)

        # 5. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        if comfyui_progressbar:
            pbar.update(1)

        if control_video is not None:
            video_length = control_video.shape[2]
            control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            control_video = control_video.to(dtype=torch.float32)
            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
        else:
            control_video = None
        control_video_latents = self.prepare_control_latents(
            None,
            control_video,
            batch_size,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            self.do_classifier_free_guidance
        )[1]
        control_latents = (
            torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
        )

        if comfyui_progressbar:
            pbar.update(1)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create image_rotary_emb, style embedding & time ids
        grid_height = height // 8 // self.transformer.config.patch_size
        grid_width = width // 8 // self.transformer.config.patch_size
        if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
            base_size_width = 720 // 8 // self.transformer.config.patch_size
            base_size_height = 480 // 8 // self.transformer.config.patch_size

            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            image_rotary_emb = get_3d_rotary_pos_embed(
                self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
                temporal_size=latents.size(2), use_real=True,
            )
        else:
            base_size = 512 // 8 // self.transformer.config.patch_size
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size, base_size
            )
            image_rotary_emb = get_2d_rotary_pos_embed(
                self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
            )

        # Get other hunyuan params
        style = torch.tensor([0], device=device)

        target_size = target_size or (height, width)
        add_time_ids = list(original_size + target_size + crops_coords_top_left)
        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
            prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
            prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
            style = torch.cat([style] * 2, dim=0)

        # To latents.device
        prompt_embeds = prompt_embeds.to(device=device)
        prompt_attention_mask = prompt_attention_mask.to(device=device)
        prompt_embeds_2 = prompt_embeds_2.to(device=device)
        prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
        add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
            batch_size * num_images_per_prompt, 1
        )
        style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

        torch.cuda.empty_cache()
        if self.enable_autocast_float8_transformer_flag:
            origin_weight_dtype = self.transformer.dtype
            self.transformer = self.transformer.to(torch.float8_e4m3fn)
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
                    dtype=latent_model_input.dtype
                )
                # predict the noise residual
                noise_pred = self.transformer(
                    latent_model_input,
                    t_expand,
                    encoder_hidden_states=prompt_embeds,
                    text_embedding_mask=prompt_attention_mask,
                    encoder_hidden_states_t5=prompt_embeds_2,
                    text_embedding_mask_t5=prompt_attention_mask_2,
                    image_meta_size=add_time_ids,
                    style=style,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                    control_latents=control_latents,
                )[0]
                if noise_pred.size()[1] != self.vae.config.latent_channels:
                    noise_pred, _ = noise_pred.chunk(2, dim=1)

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
                    negative_prompt_embeds_2 = callback_outputs.pop(
                        "negative_prompt_embeds_2", negative_prompt_embeds_2
                    )

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

                if comfyui_progressbar:
                    pbar.update(1)

        if self.enable_autocast_float8_transformer_flag:
            self.transformer = self.transformer.to("cpu", origin_weight_dtype)

        torch.cuda.empty_cache()
        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "latent":
            video = torch.from_numpy(video)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return video

        return EasyAnimatePipelineOutput(videos=video)
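The denoising loop above combines plain classifier-free guidance with the optional rescaling from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (the `rescale_noise_cfg` helper defined in these pipelines). A minimal, self-contained sketch of just that arithmetic is shown below; the tensor shapes and the `guidance_scale`/`guidance_rescale` values are illustrative assumptions, not taken from the pipeline defaults.

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the std of the guided prediction to the text-conditioned one,
    # then blend by guidance_rescale (Section 3.4 of arXiv:2305.08891).
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

# Stand-ins for the transformer's unconditional and text-conditioned predictions.
noise_pred_uncond = torch.randn(1, 4, 8, 32, 32)
noise_pred_text = torch.randn(1, 4, 8, 32, 32)
guidance_scale, guidance_rescale = 7.5, 0.7  # assumed demo values

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
print(noise_pred.shape)  # torch.Size([1, 4, 8, 32, 32])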
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py
ADDED
@@ -0,0 +1,1334 @@
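Throughout the inpaint pipeline added below, videos and masks are handled as 5-D tensors in (batch, channels, frames, height, width) layout and are temporarily flattened to (batch*frames, channels, height, width) for per-frame preprocessing and 2-D VAE encoding. A small sketch of that round trip follows; the tensor sizes are illustrative assumptions.

import torch
from einops import rearrange

video = torch.randn(1, 3, 16, 256, 256)                  # (b, c, f, h, w), made-up sizes
frames = rearrange(video, "b c f h w -> (b f) c h w")    # per-frame view for 2-D ops
assert frames.shape == (16, 3, 256, 256)
restored = rearrange(frames, "(b f) c h w -> b c f h w", f=16)
assert torch.equal(video, restored)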
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
                                         get_3d_rotary_pos_embed)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import (is_torch_xla_available, logging,
                             replace_example_docstring)
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from PIL import Image
from tqdm import tqdm
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
                          CLIPVisionModelWithProjection, T5Tokenizer,
                          T5EncoderModel)

from .pipeline_easyanimate import EasyAnimatePipelineOutput
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> pass
        ```
"""


def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


def resize_mask(mask, latent, process_first_frame_only=True):
    latent_size = latent.size()

    if process_first_frame_only:
        target_size = list(latent_size[2:])
        target_size[0] = 1
        first_frame_resized = F.interpolate(
            mask[:, :, 0:1, :, :],
            size=target_size,
            mode='trilinear',
            align_corners=False
        )

        target_size = list(latent_size[2:])
        target_size[0] = target_size[0] - 1
        if target_size[0] != 0:
            remaining_frames_resized = F.interpolate(
                mask[:, :, 1:, :, :],
                size=target_size,
                mode='trilinear',
                align_corners=False
            )
            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
        else:
            resized_mask = first_frame_resized
    else:
        target_size = list(latent_size[2:])
        resized_mask = F.interpolate(
            mask,
            size=target_size,
            mode='trilinear',
            align_corners=False
        )
    return resized_mask


def add_noise_to_reference_video(image, ratio=None):
    if ratio is None:
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio

    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
    image = image + image_noise
    return image


class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using EasyAnimate.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP
    (fine-tuned by the HunyuanDiT team).

    Args:
        vae ([`AutoencoderKLMagvit`]):
            Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
        text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
            EasyAnimate uses a fine-tuned [bilingual CLIP].
        tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
            A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
        transformer ([`EasyAnimateTransformer3DModel`]):
            The EasyAnimate model designed by Tencent Hunyuan.
        text_encoder_2 (`T5EncoderModel`):
            The mT5 embedder.
        tokenizer_2 (`T5Tokenizer`):
            The tokenizer for the mT5 embedder.
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
        clip_image_processor (`CLIPImageProcessor`):
            The image processor for the CLIP image encoder.
        clip_image_encoder (`CLIPVisionModelWithProjection`):
            The CLIP image encoder.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
    _optional_components = [
        "safety_checker",
        "feature_extractor",
        "text_encoder_2",
        "tokenizer_2",
        "text_encoder",
        "tokenizer",
        "clip_image_encoder",
    ]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "prompt_embeds_2",
        "negative_prompt_embeds_2",
    ]

    def __init__(
        self,
        vae: AutoencoderKLMagvit,
        text_encoder: BertModel,
        tokenizer: BertTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5Tokenizer,
        transformer: EasyAnimateTransformer3DModel,
        scheduler: DDIMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        clip_image_processor: CLIPImageProcessor = None,
        clip_image_encoder: CLIPVisionModelWithProjection = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            text_encoder_2=text_encoder_2,
            clip_image_processor=clip_image_processor,
            clip_image_encoder=clip_image_encoder,
        )

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )
        self.enable_autocast_float8_transformer_flag = False
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, *args, **kwargs):
        super().enable_sequential_cpu_offload(*args, **kwargs)
        if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
            import accelerate
            accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
            self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")

    def encode_prompt(
        self,
        prompt: str,
        device: torch.device,
        dtype: torch.dtype,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: Optional[int] = None,
        text_encoder_index: int = 0,
        actual_max_sequence_length: int = 256
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            dtype (`torch.dtype`):
                torch dtype
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
            text_encoder_index (`int`, *optional*):
                Index of the text encoder to use. `0` for clip and `1` for T5.
        """
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = tokenizers[text_encoder_index]
        text_encoder = text_encoders[text_encoder_index]

        if max_sequence_length is None:
            if text_encoder_index == 0:
                max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
            if text_encoder_index == 1:
                max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
        else:
            max_length = max_sequence_length

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            if text_input_ids.shape[-1] > actual_max_sequence_length:
                reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                text_inputs = tokenizer(
                    reprompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
                removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {_actual_max_sequence_length} tokens: {removed_text}"
                )
            prompt_attention_mask = text_inputs.attention_mask.to(device)
            if self.transformer.config.enable_text_attention_mask:
                prompt_embeds = text_encoder(
                    text_input_ids.to(device),
                    attention_mask=prompt_attention_mask,
                )
            else:
                prompt_embeds = text_encoder(
                    text_input_ids.to(device)
                )
            prompt_embeds = prompt_embeds[0]
            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_input.input_ids
            if uncond_input_ids.shape[-1] > actual_max_sequence_length:
                reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                uncond_input = tokenizer(
                    reuncond_tokens,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                uncond_input_ids = uncond_input.input_ids

            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
            if self.transformer.config.enable_text_attention_mask:
                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    attention_mask=negative_prompt_attention_mask,
                )
            else:
                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device)
                )
            negative_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is None and prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
            raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
            raise ValueError(
                "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
        if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
                raise ValueError(
                    "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
                    f" {negative_prompt_embeds_2.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
                bs = 1
                new_mask = []
                for i in range(0, mask.shape[0], bs):
                    mask_bs = mask[i : i + bs]
                    mask_bs = self.vae.encode(mask_bs)[0]
                    mask_bs = mask_bs.mode()
                    new_mask.append(mask_bs)
                mask = torch.cat(new_mask, dim=0)
                mask = mask * self.vae.config.scaling_factor

            else:
                if mask.shape[1] == 4:
                    mask = mask
                else:
                    video_length = mask.shape[2]
                    mask = rearrange(mask, "b c f h w -> (b f) c h w")
                    mask = self._encode_vae_image(mask, generator=generator)
                    mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)

        if masked_image is not None:
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
                bs = 1
                new_mask_pixel_values = []
                for i in range(0, masked_image.shape[0], bs):
                    mask_pixel_values_bs = masked_image[i : i + bs]
                    mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                    mask_pixel_values_bs = mask_pixel_values_bs.mode()
                    new_mask_pixel_values.append(mask_pixel_values_bs)
                masked_image_latents = torch.cat(new_mask_pixel_values, dim=0)
                masked_image_latents = masked_image_latents * self.vae.config.scaling_factor

            else:
                if masked_image.shape[1] == 4:
                    masked_image_latents = masked_image
                else:
                    video_length = masked_image.shape[2]
                    masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
                    masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
                    masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)

            # aligning device to prevent device errors when concating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        else:
            masked_image_latents = None

        return mask, masked_image_latents

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
            if self.vae.cache_mag_vae:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
            else:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
        else:
            shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if return_video_latents or (latents is None and not is_strength_max):
            video = video.to(device=device, dtype=self.vae.dtype)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
                bs = 1
                new_video = []
                for i in range(0, video.shape[0], bs):
                    video_bs = video[i : i + bs]
                    video_bs = self.vae.encode(video_bs)[0]
                    video_bs = video_bs.sample()
                    new_video.append(video_bs)
                video = torch.cat(new_video, dim=0)
                video = video * self.vae.config.scaling_factor

            else:
                if video.shape[1] == 4:
                    video = video
                else:
                    video_length = video.shape[2]
                    video = rearrange(video, "b c f h w -> (b f) c h w")
                    video = self._encode_vae_image(video, generator=generator)
                    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs

    def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
        if video.size()[2] <= mini_batch_encoder:
            return video
        prefix_index_before = mini_batch_encoder // 2
        prefix_index_after = mini_batch_encoder - prefix_index_before
        pixel_values = video[:, :, prefix_index_before:-prefix_index_after]

        # Encode middle videos
        latents = self.vae.encode(pixel_values)[0]
        latents = latents.mode()
        # Decode middle videos
        middle_video = self.vae.decode(latents)[0]

        video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
        return video

    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / self.vae.config.scaling_factor * latents
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
            mini_batch_encoder = self.vae.mini_batch_encoder
            mini_batch_decoder = self.vae.mini_batch_decoder
            video = self.vae.decode(latents)[0]
            video = video.clamp(-1, 1)
            if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
                video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
        else:
            latents = rearrange(latents, "b c f h w -> (b f) c h w")
            video = []
            for frame_idx in tqdm(range(latents.shape[0])):
                video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
            video = torch.cat(video)
            video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    def enable_autocast_float8_transformer(self):
        self.enable_autocast_float8_transformer_flag = True

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        video_length: Optional[int] = None,
        video: Union[torch.FloatTensor] = None,
        mask_video: Union[torch.FloatTensor] = None,
        masked_video_latents: Union[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "latent",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = (1024, 1024),
        target_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        clip_image: Image = None,
        clip_apply_ratio: float = 0.40,
        strength: float = 1.0,
        noise_aug_strength: float = 0.0563,
        comfyui_progressbar: bool = False,
    ):
        r"""
        The call function to the pipeline for video inpainting generation with EasyAnimate.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            video_length (`int`, *optional*):
                Length of the video to be generated, in frames. This parameter determines the number of frames and
                the continuity of generated content.
            video (`torch.FloatTensor`, *optional*):
                A tensor representing an input video, which can be modified depending on the prompts provided.
            mask_video (`torch.FloatTensor`, *optional*):
                A tensor to specify areas of the video to be masked (omitted from generation).
            masked_video_latents (`torch.FloatTensor`, *optional*):
                Latents from masked portions of the video, utilized during image generation.
            height (`int`, *optional*):
                The height in pixels of the generated image or video frames.
            width (`int`, *optional*):
                The width in pixels of the generated image or video frames.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
                inference time. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
                provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
                [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
                inference process.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
                random seeds which helps in making generation deterministic.
            latents (`torch.Tensor`, *optional*):
                A pre-computed latent representation which can be used to guide the generation process.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, embeddings are generated from the `prompt` input argument.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
                If not provided, embeddings are generated from the `negative_prompt` argument.
            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
                Secondary set of pre-generated negative text embeddings for further control.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
                `prompt_embeds`.
            prompt_attention_mask_2 (`torch.Tensor`, *optional*):
                Attention mask for the secondary prompt embedding.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
            negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
                Attention mask for the secondary negative prompt embedding.
            output_type (`str`, *optional*, defaults to `"latent"`):
                The output format of the generated image. Choose between `PIL.Image` and `np.array` to define
                how you want the results to be formatted.
            return_dict (`bool`, *optional*, defaults to `True`):
                If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned;
                otherwise, a tuple containing the generated images and safety flags will be returned.
            callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A callback function (or a list of them) that will be executed at the end of each denoising step,
                allowing for custom processing during generation.
            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
                Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
                inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
                The original dimensions of the image. Used to compute time ids during the generation process.
            target_size (`Tuple[int, int]`, *optional*):
                The targeted dimensions of the generated image, also utilized in the time id calculations.
            crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
                Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
            clip_image (`Image`, *optional*):
                An optional image to assist in the generation process. It may be used as an additional visual cue.
            clip_apply_ratio (`float`, *optional*, defaults to 0.40):
                Ratio indicating how much influence the clip image should exert over the generated content.
            strength (`float`, *optional*, defaults to 1.0):
                Affects the overall styling or quality of the generated output. Values closer to 1 usually provide direct
                adherence to prompts.
            comfyui_progressbar (`bool`, *optional*, defaults to `False`):
                Enables a progress bar in ComfyUI, providing visual feedback during the generation process.

        Examples:
            # Example usage of the function for generating images based on prompts.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                Returns either a structured output containing generated images and their metadata when `return_dict` is
                `True`, or a simpler tuple, where the first element is a list of generated images and the second
                element indicates if any of them contain "not-safe-for-work" (NSFW) content.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. default height and width
        height = int(height // 16 * 16)
        width = int(width // 16 * 16)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
            callback_on_step_end_tensor_inputs,
        )
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            text_encoder_index=0,
        )
        (
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds_2,
            negative_prompt_embeds=negative_prompt_embeds_2,
            prompt_attention_mask=prompt_attention_mask_2,
            negative_prompt_attention_mask=negative_prompt_attention_mask_2,
            text_encoder_index=1,
        )
        torch.cuda.empty_cache()

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 3)
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        if video is not None:
            video_length = video.shape[2]
            init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
            init_video = init_video.to(dtype=torch.float32)
            init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
1001 |
+
else:
|
1002 |
+
init_video = None
|
1003 |
+
|
1004 |
+
# Prepare latent variables
|
1005 |
+
num_channels_latents = self.vae.config.latent_channels
|
1006 |
+
num_channels_transformer = self.transformer.config.in_channels
|
1007 |
+
return_image_latents = num_channels_transformer == num_channels_latents
|
1008 |
+
|
1009 |
+
# 5. Prepare latents.
|
1010 |
+
latents_outputs = self.prepare_latents(
|
1011 |
+
batch_size * num_images_per_prompt,
|
1012 |
+
num_channels_latents,
|
1013 |
+
height,
|
1014 |
+
width,
|
1015 |
+
video_length,
|
1016 |
+
prompt_embeds.dtype,
|
1017 |
+
device,
|
1018 |
+
generator,
|
1019 |
+
latents,
|
1020 |
+
video=init_video,
|
1021 |
+
timestep=latent_timestep,
|
1022 |
+
is_strength_max=is_strength_max,
|
1023 |
+
return_noise=True,
|
1024 |
+
return_video_latents=return_image_latents,
|
1025 |
+
)
|
1026 |
+
if return_image_latents:
|
1027 |
+
latents, noise, image_latents = latents_outputs
|
1028 |
+
else:
|
1029 |
+
latents, noise = latents_outputs
|
1030 |
+
|
1031 |
+
if comfyui_progressbar:
|
1032 |
+
pbar.update(1)
|
1033 |
+
|
1034 |
+
# 6. Prepare clip latents if it needs.
|
1035 |
+
if clip_image is not None and self.transformer.enable_clip_in_inpaint:
|
1036 |
+
inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
|
1037 |
+
inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
|
1038 |
+
clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
|
1039 |
+
clip_encoder_hidden_states_neg = torch.zeros(
|
1040 |
+
[
|
1041 |
+
batch_size,
|
1042 |
+
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
|
1043 |
+
int(self.clip_image_encoder.config.hidden_size)
|
1044 |
+
]
|
1045 |
+
).to(latents.device, dtype=latents.dtype)
|
1046 |
+
|
1047 |
+
clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
|
1048 |
+
clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
|
1049 |
+
|
1050 |
+
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
|
1051 |
+
clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
|
1052 |
+
|
1053 |
+
elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
|
1054 |
+
clip_encoder_hidden_states = torch.zeros(
|
1055 |
+
[
|
1056 |
+
batch_size,
|
1057 |
+
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
|
1058 |
+
int(self.clip_image_encoder.config.hidden_size)
|
1059 |
+
]
|
1060 |
+
).to(latents.device, dtype=latents.dtype)
|
1061 |
+
|
1062 |
+
clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
|
1063 |
+
clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
|
1064 |
+
|
1065 |
+
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
|
1066 |
+
clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
|
1067 |
+
|
1068 |
+
else:
|
1069 |
+
clip_encoder_hidden_states_input = None
|
1070 |
+
clip_attention_mask_input = None
|
1071 |
+
if comfyui_progressbar:
|
1072 |
+
pbar.update(1)
|
1073 |
+
|
1074 |
+
# 7. Prepare inpaint latents if it needs.
|
1075 |
+
if mask_video is not None:
|
1076 |
+
if (mask_video == 255).all():
|
1077 |
+
# Use zero latents if we want to t2v.
|
1078 |
+
if self.transformer.resize_inpaint_mask_directly:
|
1079 |
+
mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype)
|
1080 |
+
else:
|
1081 |
+
mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1082 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1083 |
+
|
1084 |
+
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
|
1085 |
+
masked_video_latents_input = (
|
1086 |
+
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1087 |
+
)
|
1088 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1089 |
+
else:
|
1090 |
+
# Prepare mask latent variables
|
1091 |
+
video_length = video.shape[2]
|
1092 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
1093 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
1094 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
1095 |
+
|
1096 |
+
if num_channels_transformer != num_channels_latents:
|
1097 |
+
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
|
1098 |
+
if masked_video_latents is None:
|
1099 |
+
masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
|
1100 |
+
else:
|
1101 |
+
masked_video = masked_video_latents
|
1102 |
+
|
1103 |
+
if self.transformer.resize_inpaint_mask_directly:
|
1104 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
1105 |
+
None,
|
1106 |
+
masked_video,
|
1107 |
+
batch_size,
|
1108 |
+
height,
|
1109 |
+
width,
|
1110 |
+
prompt_embeds.dtype,
|
1111 |
+
device,
|
1112 |
+
generator,
|
1113 |
+
self.do_classifier_free_guidance,
|
1114 |
+
noise_aug_strength=noise_aug_strength,
|
1115 |
+
)
|
1116 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
|
1117 |
+
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
1118 |
+
else:
|
1119 |
+
mask_latents, masked_video_latents = self.prepare_mask_latents(
|
1120 |
+
mask_condition_tile,
|
1121 |
+
masked_video,
|
1122 |
+
batch_size,
|
1123 |
+
height,
|
1124 |
+
width,
|
1125 |
+
prompt_embeds.dtype,
|
1126 |
+
device,
|
1127 |
+
generator,
|
1128 |
+
self.do_classifier_free_guidance,
|
1129 |
+
noise_aug_strength=noise_aug_strength,
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
|
1133 |
+
masked_video_latents_input = (
|
1134 |
+
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1135 |
+
)
|
1136 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1137 |
+
else:
|
1138 |
+
inpaint_latents = None
|
1139 |
+
|
1140 |
+
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
1141 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
1142 |
+
else:
|
1143 |
+
if num_channels_transformer != num_channels_latents:
|
1144 |
+
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1145 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1146 |
+
|
1147 |
+
mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
|
1148 |
+
masked_video_latents_input = (
|
1149 |
+
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1150 |
+
)
|
1151 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1152 |
+
else:
|
1153 |
+
mask = torch.zeros_like(init_video[:, :1])
|
1154 |
+
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
|
1155 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
1156 |
+
|
1157 |
+
inpaint_latents = None
|
1158 |
+
if comfyui_progressbar:
|
1159 |
+
pbar.update(1)
|
1160 |
+
|
1161 |
+
# Check that sizes of mask, masked image and latents match
|
1162 |
+
if num_channels_transformer != num_channels_latents:
|
1163 |
+
num_channels_mask = mask_latents.shape[1]
|
1164 |
+
num_channels_masked_image = masked_video_latents.shape[1]
|
1165 |
+
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
|
1166 |
+
raise ValueError(
|
1167 |
+
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
1168 |
+
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
1169 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
1170 |
+
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
1171 |
+
" `pipeline.transformer` or your `mask_image` or `image` input."
|
1172 |
+
)
|
1173 |
+
|
1174 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1175 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1176 |
+
|
1177 |
+
# 9 create image_rotary_emb, style embedding & time ids
|
1178 |
+
grid_height = height // 8 // self.transformer.config.patch_size
|
1179 |
+
grid_width = width // 8 // self.transformer.config.patch_size
|
1180 |
+
if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
|
1181 |
+
base_size_width = 720 // 8 // self.transformer.config.patch_size
|
1182 |
+
base_size_height = 480 // 8 // self.transformer.config.patch_size
|
1183 |
+
|
1184 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
1185 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
1186 |
+
)
|
1187 |
+
image_rotary_emb = get_3d_rotary_pos_embed(
|
1188 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
|
1189 |
+
temporal_size=latents.size(2), use_real=True,
|
1190 |
+
)
|
1191 |
+
else:
|
1192 |
+
base_size = 512 // 8 // self.transformer.config.patch_size
|
1193 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
1194 |
+
(grid_height, grid_width), base_size, base_size
|
1195 |
+
)
|
1196 |
+
image_rotary_emb = get_2d_rotary_pos_embed(
|
1197 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
|
1198 |
+
)
|
1199 |
+
|
1200 |
+
# Get other hunyuan params
|
1201 |
+
style = torch.tensor([0], device=device)
|
1202 |
+
|
1203 |
+
target_size = target_size or (height, width)
|
1204 |
+
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
1205 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
1206 |
+
|
1207 |
+
if self.do_classifier_free_guidance:
|
1208 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
1209 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
1210 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
1211 |
+
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
1212 |
+
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
1213 |
+
style = torch.cat([style] * 2, dim=0)
|
1214 |
+
|
1215 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
1216 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
1217 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
1218 |
+
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
1219 |
+
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
|
1220 |
+
batch_size * num_images_per_prompt, 1
|
1221 |
+
)
|
1222 |
+
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
1223 |
+
|
1224 |
+
torch.cuda.empty_cache()
|
1225 |
+
if self.enable_autocast_float8_transformer_flag:
|
1226 |
+
origin_weight_dtype = self.transformer.dtype
|
1227 |
+
self.transformer = self.transformer.to(torch.float8_e4m3fn)
|
1228 |
+
# 10. Denoising loop
|
1229 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1230 |
+
self._num_timesteps = len(timesteps)
|
1231 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1232 |
+
for i, t in enumerate(timesteps):
|
1233 |
+
if self.interrupt:
|
1234 |
+
continue
|
1235 |
+
|
1236 |
+
# expand the latents if we are doing classifier free guidance
|
1237 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1238 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1239 |
+
|
1240 |
+
if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
|
1241 |
+
clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
|
1242 |
+
clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
|
1243 |
+
else:
|
1244 |
+
clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
|
1245 |
+
clip_attention_mask_actual_input = clip_attention_mask_input
|
1246 |
+
|
1247 |
+
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
1248 |
+
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
1249 |
+
dtype=latent_model_input.dtype
|
1250 |
+
)
|
1251 |
+
|
1252 |
+
# predict the noise residual
|
1253 |
+
noise_pred = self.transformer(
|
1254 |
+
latent_model_input,
|
1255 |
+
t_expand,
|
1256 |
+
encoder_hidden_states=prompt_embeds,
|
1257 |
+
text_embedding_mask=prompt_attention_mask,
|
1258 |
+
encoder_hidden_states_t5=prompt_embeds_2,
|
1259 |
+
text_embedding_mask_t5=prompt_attention_mask_2,
|
1260 |
+
image_meta_size=add_time_ids,
|
1261 |
+
style=style,
|
1262 |
+
image_rotary_emb=image_rotary_emb,
|
1263 |
+
inpaint_latents=inpaint_latents,
|
1264 |
+
clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
|
1265 |
+
clip_attention_mask=clip_attention_mask_actual_input,
|
1266 |
+
return_dict=False,
|
1267 |
+
)[0]
|
1268 |
+
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
1269 |
+
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
1270 |
+
|
1271 |
+
# perform guidance
|
1272 |
+
if self.do_classifier_free_guidance:
|
1273 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1274 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1275 |
+
|
1276 |
+
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
1277 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1278 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
1279 |
+
|
1280 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1281 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1282 |
+
|
1283 |
+
if num_channels_transformer == 4:
|
1284 |
+
init_latents_proper = image_latents
|
1285 |
+
init_mask = mask
|
1286 |
+
if i < len(timesteps) - 1:
|
1287 |
+
noise_timestep = timesteps[i + 1]
|
1288 |
+
init_latents_proper = self.scheduler.add_noise(
|
1289 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
1290 |
+
)
|
1291 |
+
|
1292 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
1293 |
+
|
1294 |
+
if callback_on_step_end is not None:
|
1295 |
+
callback_kwargs = {}
|
1296 |
+
for k in callback_on_step_end_tensor_inputs:
|
1297 |
+
callback_kwargs[k] = locals()[k]
|
1298 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1299 |
+
|
1300 |
+
latents = callback_outputs.pop("latents", latents)
|
1301 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1302 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1303 |
+
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
|
1304 |
+
negative_prompt_embeds_2 = callback_outputs.pop(
|
1305 |
+
"negative_prompt_embeds_2", negative_prompt_embeds_2
|
1306 |
+
)
|
1307 |
+
|
1308 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1309 |
+
progress_bar.update()
|
1310 |
+
|
1311 |
+
if XLA_AVAILABLE:
|
1312 |
+
xm.mark_step()
|
1313 |
+
|
1314 |
+
if comfyui_progressbar:
|
1315 |
+
pbar.update(1)
|
1316 |
+
|
1317 |
+
if self.enable_autocast_float8_transformer_flag:
|
1318 |
+
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
1319 |
+
|
1320 |
+
torch.cuda.empty_cache()
|
1321 |
+
# Post-processing
|
1322 |
+
video = self.decode_latents(latents)
|
1323 |
+
|
1324 |
+
# Convert to tensor
|
1325 |
+
if output_type == "latent":
|
1326 |
+
video = torch.from_numpy(video)
|
1327 |
+
|
1328 |
+
# Offload all models
|
1329 |
+
self.maybe_free_model_hooks()
|
1330 |
+
|
1331 |
+
if not return_dict:
|
1332 |
+
return video
|
1333 |
+
|
1334 |
+
return EasyAnimatePipelineOutput(videos=video)
|
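Note (not part of the diff): a minimal, hedged usage sketch for the inpaint `__call__` above. It assumes `pipeline` is an already-constructed instance of the pipeline defined in this file (with VAE, transformer and both text encoders loaded); only keyword arguments that appear in the code above are used, and the prompt/shape values are placeholders.

import torch

with torch.no_grad():
    sample = pipeline(
        prompt="a cat running on the grass",
        video=input_video,            # (b, c, f, h, w) pixel tensor, or None for plain text-to-video
        mask_video=input_video_mask,  # 255 marks regions to regenerate; an all-255 mask behaves like t2v
        clip_image=clip_image,        # optional image cue, used when the transformer enables CLIP inpaint
        video_length=49,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.0,
        strength=1.0,                 # 1.0 starts from pure noise (is_strength_max above)
    ).videos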
easyanimate/ui/ui.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
easyanimate/utils/discrete_sampler.py
ADDED
@@ -0,0 +1,46 @@
"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
"""
import torch

class DiscreteSampling:
    def __init__(self, num_idx, uniform_sampling=False):
        self.num_idx = num_idx
        self.uniform_sampling = uniform_sampling
        self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

        if self.is_distributed and self.uniform_sampling:
            world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()

            i = 1
            while True:
                if world_size % i != 0 or num_idx % (world_size // i) != 0:
                    i += 1
                else:
                    self.group_num = world_size // i
                    break
            assert self.group_num > 0
            assert world_size % self.group_num == 0
            # the number of ranks in one group
            self.group_width = world_size // self.group_num
            self.sigma_interval = self.num_idx // self.group_num
            print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
                self.rank, world_size, self.group_num,
                self.group_width, self.sigma_interval))

    def __call__(self, n_samples, generator=None, device=None):
        if self.is_distributed and self.uniform_sampling:
            group_index = self.rank // self.group_width
            idx = torch.randint(
                group_index * self.sigma_interval,
                (group_index + 1) * self.sigma_interval,
                (n_samples,),
                generator=generator, device=device,
            )
            print('proc[%d] idx=%s' % (self.rank, idx))
        else:
            idx = torch.randint(
                0, self.num_idx, (n_samples,),
                generator=generator, device=device,
            )
        return idx
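Note (not part of the diff): a hedged usage sketch for `DiscreteSampling`; the 1000-step value and the seed are assumptions. With `uniform_sampling=True` under `torch.distributed`, each rank draws from its own contiguous slice of `[0, num_idx)`, so a global batch covers the timestep range more evenly.

import torch

sampler = DiscreteSampling(num_idx=1000, uniform_sampling=False)
generator = torch.Generator().manual_seed(42)
timesteps = sampler(n_samples=4, generator=generator, device="cpu")
print(timesteps)  # four integer timestep indices in [0, 1000)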
easyanimate/utils/fp8_optimization.py
ADDED
@@ -0,0 +1,28 @@
"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
"""
import torch
import torch.nn as nn

def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    weight_dtype = cls.weight.dtype
    cls.to(origin_dtype)

    # Convert all inputs to the original dtype
    inputs = [input.to(origin_dtype) for input in inputs]
    out = cls.original_forward(*inputs, **kwargs)

    cls.to(weight_dtype)
    return out

def convert_weight_dtype_wrapper(module, origin_dtype):
    for name, module in module.named_modules():
        if name == "":
            continue
        original_forward = module.forward
        if hasattr(module, "weight"):
            setattr(module, "original_forward", original_forward)
            setattr(
                module,
                "forward",
                lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
            )
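Note (not part of the diff): a hedged sketch of how these helpers pair with the `float8_e4m3fn` cast used in the pipelines above; the bfloat16 compute dtype and the pre-loaded `transformer` are assumptions.

import torch

compute_dtype = torch.bfloat16                              # assumed compute dtype
transformer = transformer.to(torch.float8_e4m3fn)           # store weights in fp8 to save memory
convert_weight_dtype_wrapper(transformer, compute_dtype)    # each wrapped module upcasts its weights for the forward pass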
easyanimate/utils/lora_utils.py
CHANGED
@@ -156,8 +156,8 @@ def precalculate_safetensors_hashes(tensors, metadata):
|
|
156 |
|
157 |
|
158 |
class LoRANetwork(torch.nn.Module):
|
159 |
-
TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
|
160 |
-
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
|
161 |
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
162 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
163 |
def __init__(
|
@@ -238,9 +238,10 @@ class LoRANetwork(torch.nn.Module):
|
|
238 |
self.text_encoder_loras = []
|
239 |
skipped_te = []
|
240 |
for i, text_encoder in enumerate(text_encoders):
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
244 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
245 |
|
246 |
self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
|
@@ -368,6 +369,7 @@ def create_network(
|
|
368 |
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
|
369 |
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
370 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
|
|
371 |
if state_dict is None:
|
372 |
state_dict = load_file(lora_path, device=device)
|
373 |
else:
|
@@ -389,21 +391,24 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
|
|
389 |
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
|
390 |
curr_layer = pipeline.transformer
|
391 |
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
temp_name
|
|
|
|
|
|
|
407 |
|
408 |
weight_up = elems['lora_up.weight'].to(dtype)
|
409 |
weight_down = elems['lora_down.weight'].to(dtype)
|
@@ -444,6 +449,7 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
|
|
444 |
curr_layer = pipeline.transformer
|
445 |
|
446 |
temp_name = layer_infos.pop(0)
|
|
|
447 |
while len(layer_infos) > -1:
|
448 |
try:
|
449 |
curr_layer = curr_layer.__getattr__(temp_name)
|
|
|
156 |
|
157 |
|
158 |
class LoRANetwork(torch.nn.Module):
|
159 |
+
TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel", "HunyuanTransformer3DModel", "EasyAnimateTransformer3DModel"]
|
160 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
|
161 |
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
162 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
163 |
def __init__(
|
|
|
238 |
self.text_encoder_loras = []
|
239 |
skipped_te = []
|
240 |
for i, text_encoder in enumerate(text_encoders):
|
241 |
+
if text_encoder is not None:
|
242 |
+
text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
243 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
244 |
+
skipped_te += skipped
|
245 |
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
246 |
|
247 |
self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
|
|
|
369 |
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
|
370 |
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
371 |
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
372 |
+
SPECIAL_LAYER_NAME = ["text_proj_t5"]
|
373 |
if state_dict is None:
|
374 |
state_dict = load_file(lora_path, device=device)
|
375 |
else:
|
|
|
391 |
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
|
392 |
curr_layer = pipeline.transformer
|
393 |
|
394 |
+
try:
|
395 |
+
curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
|
396 |
+
except Exception:
|
397 |
+
temp_name = layer_infos.pop(0)
|
398 |
+
while len(layer_infos) > -1:
|
399 |
+
try:
|
400 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
401 |
+
if len(layer_infos) > 0:
|
402 |
+
temp_name = layer_infos.pop(0)
|
403 |
+
elif len(layer_infos) == 0:
|
404 |
+
break
|
405 |
+
except Exception:
|
406 |
+
if len(layer_infos) == 0:
|
407 |
+
print('Error loading layer')
|
408 |
+
if len(temp_name) > 0:
|
409 |
+
temp_name += "_" + layer_infos.pop(0)
|
410 |
+
else:
|
411 |
+
temp_name = layer_infos.pop(0)
|
412 |
|
413 |
weight_up = elems['lora_up.weight'].to(dtype)
|
414 |
weight_down = elems['lora_down.weight'].to(dtype)
|
|
|
449 |
curr_layer = pipeline.transformer
|
450 |
|
451 |
temp_name = layer_infos.pop(0)
|
452 |
+
print(layer, curr_layer)
|
453 |
while len(layer_infos) > -1:
|
454 |
try:
|
455 |
curr_layer = curr_layer.__getattr__(temp_name)
|
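Note (not part of the diff): a hedged sketch of how `merge_lora` / `unmerge_lora` above are typically wrapped around inference. Only the signatures visible in this diff are used; the LoRA path is a placeholder, and the sketch assumes both helpers return the updated pipeline.

import torch

# Fuse LoRA weights into the transformer (and text encoders) before sampling ...
pipeline = merge_lora(pipeline, "path/to/lora.safetensors", multiplier=0.8,
                      device="cuda", dtype=torch.bfloat16)
# ... run inference here ...
# ... then restore the base weights afterwards.
pipeline = unmerge_lora(pipeline, "path/to/lora.safetensors", multiplier=0.8,
                        device="cuda", dtype=torch.bfloat16)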
easyanimate/utils/utils.py
CHANGED
@@ -1,13 +1,15 @@
|
|
|
|
1 |
import os
|
2 |
|
|
|
3 |
import imageio
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torchvision
|
7 |
-
import cv2
|
8 |
from einops import rearrange
|
9 |
from PIL import Image
|
10 |
|
|
|
11 |
def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
|
12 |
target_pixels = int(base_resolution) * int(base_resolution)
|
13 |
original_width, original_height = Image.open(image).size
|
@@ -73,13 +75,20 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
|
|
73 |
def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
|
74 |
if validation_image_start is not None and validation_image_end is not None:
|
75 |
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
|
76 |
-
image_start = clip_image = Image.open(validation_image_start)
|
|
|
|
|
77 |
else:
|
78 |
image_start = clip_image = validation_image_start
|
|
|
|
|
|
|
79 |
if type(validation_image_end) is str and os.path.isfile(validation_image_end):
|
80 |
-
image_end = Image.open(validation_image_end)
|
|
|
81 |
else:
|
82 |
image_end = validation_image_end
|
|
|
83 |
|
84 |
if type(image_start) is list:
|
85 |
clip_image = clip_image[0]
|
@@ -119,8 +128,13 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
|
|
119 |
elif validation_image_start is not None:
|
120 |
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
|
121 |
image_start = clip_image = Image.open(validation_image_start).convert("RGB")
|
|
|
|
|
122 |
else:
|
123 |
image_start = clip_image = validation_image_start
|
|
|
|
|
|
|
124 |
|
125 |
if type(image_start) is list:
|
126 |
clip_image = clip_image[0]
|
@@ -142,30 +156,60 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
|
|
142 |
input_video_mask = torch.zeros_like(input_video[:, :1])
|
143 |
input_video_mask[:, :, 1:, ] = 255
|
144 |
else:
|
|
|
|
|
145 |
input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
|
146 |
input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
|
147 |
clip_image = None
|
148 |
|
|
|
|
|
|
|
|
|
149 |
return input_video, input_video_mask, clip_image
|
150 |
|
151 |
-
def
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
input_video = torch.from_numpy(np.array(input_video))[:video_length]
|
166 |
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
|
167 |
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
-
return input_video, input_video_mask,
|
|
|
1 |
+
import gc
|
2 |
import os
|
3 |
|
4 |
+
import cv2
|
5 |
import imageio
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
import torchvision
|
|
|
9 |
from einops import rearrange
|
10 |
from PIL import Image
|
11 |
|
12 |
+
|
13 |
def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
|
14 |
target_pixels = int(base_resolution) * int(base_resolution)
|
15 |
original_width, original_height = Image.open(image).size
|
|
|
75 |
def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
|
76 |
if validation_image_start is not None and validation_image_end is not None:
|
77 |
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
|
78 |
+
image_start = clip_image = Image.open(validation_image_start).convert("RGB")
|
79 |
+
image_start = image_start.resize([sample_size[1], sample_size[0]])
|
80 |
+
clip_image = clip_image.resize([sample_size[1], sample_size[0]])
|
81 |
else:
|
82 |
image_start = clip_image = validation_image_start
|
83 |
+
image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
|
84 |
+
clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
|
85 |
+
|
86 |
if type(validation_image_end) is str and os.path.isfile(validation_image_end):
|
87 |
+
image_end = Image.open(validation_image_end).convert("RGB")
|
88 |
+
image_end = image_end.resize([sample_size[1], sample_size[0]])
|
89 |
else:
|
90 |
image_end = validation_image_end
|
91 |
+
image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
|
92 |
|
93 |
if type(image_start) is list:
|
94 |
clip_image = clip_image[0]
|
|
|
128 |
elif validation_image_start is not None:
|
129 |
if type(validation_image_start) is str and os.path.isfile(validation_image_start):
|
130 |
image_start = clip_image = Image.open(validation_image_start).convert("RGB")
|
131 |
+
image_start = image_start.resize([sample_size[1], sample_size[0]])
|
132 |
+
clip_image = clip_image.resize([sample_size[1], sample_size[0]])
|
133 |
else:
|
134 |
image_start = clip_image = validation_image_start
|
135 |
+
image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
|
136 |
+
clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
|
137 |
+
image_end = None
|
138 |
|
139 |
if type(image_start) is list:
|
140 |
clip_image = clip_image[0]
|
|
|
156 |
input_video_mask = torch.zeros_like(input_video[:, :1])
|
157 |
input_video_mask[:, :, 1:, ] = 255
|
158 |
else:
|
159 |
+
image_start = None
|
160 |
+
image_end = None
|
161 |
input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
|
162 |
input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
|
163 |
clip_image = None
|
164 |
|
165 |
+
del image_start
|
166 |
+
del image_end
|
167 |
+
gc.collect()
|
168 |
+
|
169 |
return input_video, input_video_mask, clip_image
|
170 |
|
171 |
+
def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
|
172 |
+
if isinstance(input_video_path, str):
|
173 |
+
cap = cv2.VideoCapture(input_video_path)
|
174 |
+
input_video = []
|
175 |
+
|
176 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
177 |
+
frame_skip = 1 if fps is None else int(original_fps // fps)
|
178 |
+
|
179 |
+
frame_count = 0
|
180 |
+
|
181 |
+
while True:
|
182 |
+
ret, frame = cap.read()
|
183 |
+
if not ret:
|
184 |
+
break
|
185 |
+
|
186 |
+
if frame_count % frame_skip == 0:
|
187 |
+
frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
|
188 |
+
input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
189 |
+
|
190 |
+
frame_count += 1
|
191 |
+
|
192 |
+
cap.release()
|
193 |
+
else:
|
194 |
+
input_video = input_video_path
|
195 |
+
|
196 |
input_video = torch.from_numpy(np.array(input_video))[:video_length]
|
197 |
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
|
198 |
|
199 |
+
if ref_image is not None:
|
200 |
+
ref_image = Image.open(ref_image)
|
201 |
+
ref_image = torch.from_numpy(np.array(ref_image))
|
202 |
+
ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
|
203 |
+
|
204 |
+
if validation_video_mask is not None:
|
205 |
+
validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
|
206 |
+
input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
|
207 |
+
|
208 |
+
input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
|
209 |
+
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
|
210 |
+
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
|
211 |
+
else:
|
212 |
+
input_video_mask = torch.zeros_like(input_video[:, :1])
|
213 |
+
input_video_mask[:, :, :] = 255
|
214 |
|
215 |
+
return input_video, input_video_mask, ref_image
|
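Note (not part of the diff): a hedged sketch tying the two helpers above to the pipeline inputs; file paths and the 8 fps value are placeholders.

sample_size = (512, 512)   # (height, width)
video_length = 49

# Image-to-video: a start frame only (end frame omitted).
input_video, input_video_mask, clip_image = get_image_to_video_latent(
    "start_frame.png", None, video_length=video_length, sample_size=sample_size)

# Video-to-video: resample an existing clip, optionally at a reduced frame rate.
input_video, input_video_mask, ref_image = get_video_to_video_latent(
    "input.mp4", video_length=video_length, sample_size=sample_size, fps=8)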
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml
ADDED
@@ -0,0 +1,64 @@
model:
  base_learning_rate: 1.0e-04
  target: easyanimate.vae.ldm.models.cogvideox_casual3dcnn.AutoencoderKLMagvit_CogVideoX
  params:
    latent_channels: 16
    temporal_compression_ratio: 4
    monitor: train/rec_loss
    ckpt_path: vae/diffusion_pytorch_model.safetensors
    down_block_types: ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
      "CogVideoXDownBlock3D",)
    up_block_types: ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
      "CogVideoXUpBlock3D",)
    lossconfig:
      target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
        l2_loss_weight: 0.1
        l1_loss_weight: 1.0
        perceptual_weight: 1.0

data:
  target: train_vae.DataModuleFromConfig

  params:
    batch_size: 1
    wrap: true
    num_workers: 8
    train:
      target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
      params:
        data_json_path: pretrain.json
        data_root: /your_data_root # This is used in relative path
        size: 256
        degradation: pil_nearest
        video_size: 256
        video_len: 49
        slice_interval: 1
    validation:
      target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
      params:
        data_json_path: pretrain.json
        data_root: /your_data_root # This is used in relative path
        size: 256
        degradation: pil_nearest
        video_size: 256
        video_len: 49
        slice_interval: 1

lightning:
  callbacks:
    image_logger:
      target: train_vae.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 1
    gpus: "0"
    num_nodes: 1
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml
ADDED
@@ -0,0 +1,65 @@
model:
  base_learning_rate: 1.0e-04
  target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
  params:
    spatial_group_norm: true
    mid_block_attention_type: "spatial"
    latent_channels: 16
    monitor: train/rec_loss
    ckpt_path: vae/diffusion_pytorch_model.safetensors
    down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
      "SpatialTemporalDownBlock3D",)
    up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
      "SpatialTemporalUpBlock3D",)
    lossconfig:
      target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
        l2_loss_weight: 0.1
        l1_loss_weight: 1.0
        perceptual_weight: 1.0

data:
  target: train_vae.DataModuleFromConfig

  params:
    batch_size: 1
    wrap: true
    num_workers: 8
    train:
      target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
      params:
        data_json_path: pretrain.json
        data_root: /your_data_root # This is used in relative path
        size: 256
        degradation: pil_nearest
        video_size: 256
        video_len: 49
        slice_interval: 1
    validation:
      target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
      params:
        data_json_path: pretrain.json
        data_root: /your_data_root # This is used in relative path
        size: 256
        degradation: pil_nearest
        video_size: 256
        video_len: 49
        slice_interval: 1

lightning:
  callbacks:
    image_logger:
      target: train_vae.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 1
    gpus: "0"
    num_nodes: 1
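Note (not part of the diff): a hedged sketch of how a config like the two above is typically consumed in this ldm-style package, via OmegaConf and `instantiate_from_config`; the import path is inferred from `from ..util import instantiate_from_config` in the model files below and is an assumption.

from omegaconf import OmegaConf
from easyanimate.vae.ldm.util import instantiate_from_config  # assumed module path

config = OmegaConf.load("easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml")
vae = instantiate_from_config(config.model)  # builds the autoencoder and its loss; expects ckpt_path to exist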
easyanimate/vae/ldm/data/dataset_callback.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
#-*- encoding:utf-8 -*-
|
2 |
from pytorch_lightning.callbacks import Callback
|
3 |
|
|
|
4 |
class DatasetCallback(Callback):
|
5 |
def __init__(self):
|
6 |
self.sampler_pos_start = 0
|
|
|
1 |
#-*- encoding:utf-8 -*-
|
2 |
from pytorch_lightning.callbacks import Callback
|
3 |
|
4 |
+
|
5 |
class DatasetCallback(Callback):
|
6 |
def __init__(self):
|
7 |
self.sampler_pos_start = 0
|
easyanimate/vae/ldm/data/dataset_image_video.py
CHANGED
@@ -17,7 +17,7 @@ from decord import VideoReader
|
|
17 |
from func_timeout import FunctionTimedOut, func_set_timeout
|
18 |
from omegaconf import OmegaConf
|
19 |
from PIL import Image
|
20 |
-
from torch.utils.data import
|
21 |
from tqdm import tqdm
|
22 |
|
23 |
from ..modules.image_degradation import (degradation_fn_bsr,
|
@@ -164,15 +164,18 @@ class ImageVideoDataset(Dataset):
|
|
164 |
return self.base[index].get('type', 'image')
|
165 |
|
166 |
def __getitem__(self, i):
|
167 |
-
@func_set_timeout(
|
168 |
def get_video_item(example):
|
169 |
if self.data_root is not None:
|
170 |
video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
|
171 |
else:
|
172 |
video_reader = VideoReader(example['file_path'])
|
173 |
video_length = len(video_reader)
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
176 |
start_idx = random.randint(0, video_length - clip_length)
|
177 |
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
|
178 |
|
|
|
17 |
from func_timeout import FunctionTimedOut, func_set_timeout
|
18 |
from omegaconf import OmegaConf
|
19 |
from PIL import Image
|
20 |
+
from torch.utils.data import BatchSampler, Dataset, Sampler
|
21 |
from tqdm import tqdm
|
22 |
|
23 |
from ..modules.image_degradation import (degradation_fn_bsr,
|
|
|
164 |
return self.base[index].get('type', 'image')
|
165 |
|
166 |
def __getitem__(self, i):
|
167 |
+
@func_set_timeout(15) # abort the read if it takes longer than 15 seconds
|
168 |
def get_video_item(example):
|
169 |
if self.data_root is not None:
|
170 |
video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
|
171 |
else:
|
172 |
video_reader = VideoReader(example['file_path'])
|
173 |
video_length = len(video_reader)
|
174 |
+
if self.slice_interval == "rand":
|
175 |
+
slice_interval = np.random.choice([1, 2, 3])
|
176 |
+
else:
|
177 |
+
slice_interval = int(self.slice_interval)
|
178 |
+
clip_length = min(video_length, (self.video_len - 1) * slice_interval + 1)
|
179 |
start_idx = random.randint(0, video_length - clip_length)
|
180 |
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
|
181 |
|
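Note (not part of the diff): the frame-index arithmetic added above, shown in isolation as a hedged sketch with example numbers.

import random
import numpy as np

video_length, video_len, slice_interval = 200, 49, 2                     # example values
clip_length = min(video_length, (video_len - 1) * slice_interval + 1)    # 97 source frames covered
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, video_len, dtype=int)
# batch_index holds 49 indices spaced roughly `slice_interval` frames apart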
easyanimate/vae/ldm/models/casual3dcnn.py
ADDED
@@ -0,0 +1,337 @@
1 |
+
import time
|
2 |
+
from contextlib import contextmanager
|
3 |
+
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from ..modules.diffusionmodules.model import Decoder, Encoder
|
9 |
+
from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
10 |
+
from ..util import instantiate_from_config
|
11 |
+
from .enc_dec import Decoder as Mag_Decoder
|
12 |
+
from .enc_dec import Encoder as Mag_Encoder
|
13 |
+
|
14 |
+
|
15 |
+
class AutoencoderKLMagvit(pl.LightningModule):
|
16 |
+
def __init__(self,
|
17 |
+
ddconfig,
|
18 |
+
lossconfig,
|
19 |
+
embed_dim,
|
20 |
+
ckpt_path=None,
|
21 |
+
ignore_keys=[],
|
22 |
+
image_key="image",
|
23 |
+
colorize_nlabels=None,
|
24 |
+
monitor=None,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.image_key = image_key
|
28 |
+
self.encoder = Mag_Encoder()
|
29 |
+
self.decoder = Mag_Decoder()
|
30 |
+
self.loss = instantiate_from_config(lossconfig)
|
31 |
+
self.quant_conv = torch.nn.Conv3d(16, 16, 1)
|
32 |
+
self.post_quant_conv = torch.nn.Conv3d(8, 8, 1)
|
33 |
+
self.embed_dim = embed_dim
|
34 |
+
if colorize_nlabels is not None:
|
35 |
+
assert type(colorize_nlabels)==int
|
36 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
37 |
+
if monitor is not None:
|
38 |
+
self.monitor = monitor
|
39 |
+
if ckpt_path is not None:
|
40 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
41 |
+
|
42 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
43 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
44 |
+
keys = list(sd.keys())
|
45 |
+
for k in keys:
|
46 |
+
for ik in ignore_keys:
|
47 |
+
if k.startswith(ik):
|
48 |
+
print("Deleting key {} from state_dict.".format(k))
|
49 |
+
del sd[k]
|
50 |
+
self.load_state_dict(sd, strict=False)
|
51 |
+
print(f"Restored from {path}")
|
52 |
+
|
53 |
+
def encode(self, x):
|
54 |
+
h = self.encoder(x)
|
55 |
+
moments = self.quant_conv(h)
|
56 |
+
posterior = DiagonalGaussianDistribution(moments)
|
57 |
+
return posterior
|
58 |
+
|
59 |
+
def decode(self, z):
|
60 |
+
z = self.post_quant_conv(z)
|
61 |
+
dec = self.decoder(z)
|
62 |
+
return dec
|
63 |
+
|
64 |
+
def forward(self, input, sample_posterior=True):
|
65 |
+
if input.ndim==4:
|
66 |
+
input = input.unsqueeze(2)
|
67 |
+
posterior = self.encode(input)
|
68 |
+
if sample_posterior:
|
69 |
+
z = posterior.sample()
|
70 |
+
else:
|
71 |
+
z = posterior.mode()
|
72 |
+
dec = self.decode(z)
|
73 |
+
return dec, posterior
|
74 |
+
|
75 |
+
def get_input(self, batch, k):
|
76 |
+
x = batch[k]
|
77 |
+
if x.ndim==5:
|
78 |
+
x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
|
79 |
+
return x
|
80 |
+
if len(x.shape) == 3:
|
81 |
+
x = x[..., None]
|
82 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
83 |
+
return x
|
84 |
+
|
85 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
86 |
+
# tic = time.time()
|
87 |
+
inputs = self.get_input(batch, self.image_key)
|
88 |
+
# print(f"get_input time {time.time() - tic}")
|
89 |
+
# tic = time.time()
|
90 |
+
reconstructions, posterior = self(inputs)
|
91 |
+
# print(f"model forward time {time.time() - tic}")
|
92 |
+
|
93 |
+
if optimizer_idx == 0:
|
94 |
+
# train encoder+decoder+logvar
|
95 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
96 |
+
last_layer=self.get_last_layer(), split="train")
|
97 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
98 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
99 |
+
# print(f"cal loss time {time.time() - tic}")
|
100 |
+
return aeloss
|
101 |
+
|
102 |
+
if optimizer_idx == 1:
|
103 |
+
# train the discriminator
|
104 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
105 |
+
last_layer=self.get_last_layer(), split="train")
|
106 |
+
|
107 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
108 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
109 |
+
# print(f"cal loss time {time.time() - tic}")
|
110 |
+
return discloss
|
111 |
+
|
112 |
+
def validation_step(self, batch, batch_idx):
|
113 |
+
with torch.no_grad():
|
114 |
+
inputs = self.get_input(batch, self.image_key)
|
115 |
+
reconstructions, posterior = self(inputs)
|
116 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
117 |
+
last_layer=self.get_last_layer(), split="val")
|
118 |
+
|
119 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
120 |
+
last_layer=self.get_last_layer(), split="val")
|
121 |
+
|
122 |
+
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
123 |
+
self.log_dict(log_dict_ae)
|
124 |
+
self.log_dict(log_dict_disc)
|
125 |
+
return self.log_dict
|
126 |
+
|
127 |
+
def configure_optimizers(self):
|
128 |
+
lr = self.learning_rate
|
129 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
130 |
+
list(self.decoder.parameters())+
|
131 |
+
list(self.quant_conv.parameters())+
|
132 |
+
list(self.post_quant_conv.parameters()),
|
133 |
+
lr=lr, betas=(0.5, 0.9))
|
134 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
135 |
+
lr=lr, betas=(0.5, 0.9))
|
136 |
+
return [opt_ae, opt_disc], []
|
137 |
+
|
138 |
+
def get_last_layer(self):
|
139 |
+
return self.decoder.conv_out.weight
|
140 |
+
|
141 |
+
@torch.no_grad()
|
142 |
+
def log_images(self, batch, only_inputs=False, **kwargs):
|
143 |
+
log = dict()
|
144 |
+
x = self.get_input(batch, self.image_key)
|
145 |
+
x = x.to(self.device)
|
146 |
+
if not only_inputs:
|
147 |
+
xrec, posterior = self(x)
|
148 |
+
if x.shape[1] > 3:
|
149 |
+
# colorize with random projection
|
150 |
+
assert xrec.shape[1] > 3
|
151 |
+
x = self.to_rgb(x)
|
152 |
+
xrec = self.to_rgb(xrec)
|
153 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
154 |
+
log["reconstructions"] = xrec
|
155 |
+
log["inputs"] = x
|
156 |
+
return log
|
157 |
+
|
158 |
+
def to_rgb(self, x):
|
159 |
+
assert self.image_key == "segmentation"
|
160 |
+
if not hasattr(self, "colorize"):
|
161 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
162 |
+
x = F.conv2d(x, weight=self.colorize)
|
163 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
164 |
+
return x
|
165 |
+
|
166 |
+
class AutoencoderKL(pl.LightningModule):
|
167 |
+
def __init__(self,
|
168 |
+
ddconfig,
|
169 |
+
lossconfig,
|
170 |
+
embed_dim,
|
171 |
+
ckpt_path=None,
|
172 |
+
ignore_keys=[],
|
173 |
+
image_key="image",
|
174 |
+
colorize_nlabels=None,
|
175 |
+
monitor=None,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
self.image_key = image_key
|
179 |
+
self.encoder = Encoder(**ddconfig)
|
180 |
+
self.decoder = Decoder(**ddconfig)
|
181 |
+
self.loss = instantiate_from_config(lossconfig)
|
182 |
+
assert ddconfig["double_z"]
|
183 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
184 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
185 |
+
self.embed_dim = embed_dim
|
186 |
+
if colorize_nlabels is not None:
|
187 |
+
assert type(colorize_nlabels)==int
|
188 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
189 |
+
if monitor is not None:
|
190 |
+
self.monitor = monitor
|
191 |
+
if ckpt_path is not None:
|
192 |
+
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        # tic = time.time()
+        inputs = self.get_input(batch, self.image_key)
+        # print(f"get_input time {time.time() - tic}")
+        # tic = time.time()
+        reconstructions, posterior = self(inputs)
+        # print(f"model forward time {time.time() - tic}")
+        tic = time.time()
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            # print(f"cal loss time {time.time() - tic}")
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            # print(f"cal loss time {time.time() - tic}")
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        tic = time.time()
+        inputs = self.get_input(batch, self.image_key)
+        print(f"get_input time {time.time() - tic}")
+        tic = time.time()
+        reconstructions, posterior = self(inputs)
+        print(f"val forward time {time.time() - tic}")
+        tic = time.time()
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val")
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val")
+
+        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        print(f"val end time {time.time() - tic}")
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+                                  list(self.decoder.parameters())+
+                                  list(self.quant_conv.parameters())+
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
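For orientation, the forward pass added above reduces to encode, sample a diagonal Gaussian, decode. Below is a minimal, self-contained sketch of just the sampling step using plain torch; the tensor shape is an illustrative assumption, not taken from the repo.

import torch

# Illustrative "moments" tensor: 2 * latent_channels along dim 1 (mean then logvar),
# as produced by quant_conv in the encoder above. The shape here is made up for the sketch.
moments = torch.randn(2, 8, 32, 32)        # (B, 2*C_latent, H', W')
mean, logvar = moments.chunk(2, dim=1)     # split into mean and log-variance
logvar = torch.clamp(logvar, -30.0, 20.0)  # same clamping used by DiagonalGaussianDistribution
std = torch.exp(0.5 * logvar)

# Reparameterized sample (posterior.sample()); posterior.mode() would simply return `mean`.
z = mean + std * torch.randn_like(std)
print(z.shape)                             # torch.Size([2, 4, 32, 32])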
easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
from ..util import instantiate_from_config
|
12 |
+
from .cogvideox_enc_dec import (CogVideoXDecoder3D, CogVideoXEncoder3D,
|
13 |
+
CogVideoXSafeConv3d)
|
14 |
+
|
15 |
+
|
16 |
+
class DiagonalGaussianDistribution:
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
mean: torch.Tensor,
|
20 |
+
logvar: torch.Tensor,
|
21 |
+
deterministic: bool = False,
|
22 |
+
):
|
23 |
+
self.mean = mean
|
24 |
+
self.logvar = torch.clamp(logvar, -30.0, 20.0)
|
25 |
+
self.deterministic = deterministic
|
26 |
+
|
27 |
+
if deterministic:
|
28 |
+
self.var = self.std = torch.zeros_like(self.mean)
|
29 |
+
else:
|
30 |
+
self.std = torch.exp(0.5 * self.logvar)
|
31 |
+
self.var = torch.exp(self.logvar)
|
32 |
+
|
33 |
+
def sample(self, generator = None) -> torch.FloatTensor:
|
34 |
+
x = torch.randn(
|
35 |
+
self.mean.shape,
|
36 |
+
generator=generator,
|
37 |
+
device=self.mean.device,
|
38 |
+
dtype=self.mean.dtype,
|
39 |
+
)
|
40 |
+
return self.mean + self.std * x
|
41 |
+
|
42 |
+
def mode(self):
|
43 |
+
return self.mean
|
44 |
+
|
45 |
+
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
|
46 |
+
dims = list(range(1, self.mean.ndim))
|
47 |
+
|
48 |
+
if self.deterministic:
|
49 |
+
return torch.Tensor([0.0])
|
50 |
+
else:
|
51 |
+
if other is None:
|
52 |
+
return 0.5 * torch.sum(
|
53 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
54 |
+
dim=dims,
|
55 |
+
)
|
56 |
+
else:
|
57 |
+
return 0.5 * torch.sum(
|
58 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
59 |
+
+ self.var / other.var
|
60 |
+
- 1.0
|
61 |
+
- self.logvar
|
62 |
+
+ other.logvar,
|
63 |
+
dim=dims,
|
64 |
+
)
|
65 |
+
|
66 |
+
def nll(self, sample: torch.Tensor) -> torch.Tensor:
|
67 |
+
dims = list(range(1, self.mean.ndim))
|
68 |
+
|
69 |
+
if self.deterministic:
|
70 |
+
return torch.Tensor([0.0])
|
71 |
+
|
72 |
+
logtwopi = np.log(2.0 * np.pi)
|
73 |
+
return 0.5 * torch.sum(
|
74 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
75 |
+
dim=dims,
|
76 |
+
)
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class EncoderOutput:
|
80 |
+
latent_dist: DiagonalGaussianDistribution
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class DecoderOutput:
|
84 |
+
sample: torch.Tensor
|
85 |
+
|
86 |
+
def str_eval(item):
|
87 |
+
if type(item) == str:
|
88 |
+
return eval(item)
|
89 |
+
else:
|
90 |
+
return item
|
91 |
+
|
92 |
+
class AutoencoderKLMagvit_CogVideoX(pl.LightningModule):
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
in_channels: int = 3,
|
96 |
+
out_channels: int = 3,
|
97 |
+
down_block_types: Tuple[str] = (
|
98 |
+
"CogVideoXDownBlock3D",
|
99 |
+
"CogVideoXDownBlock3D",
|
100 |
+
"CogVideoXDownBlock3D",
|
101 |
+
"CogVideoXDownBlock3D",
|
102 |
+
),
|
103 |
+
up_block_types: Tuple[str] = (
|
104 |
+
"CogVideoXUpBlock3D",
|
105 |
+
"CogVideoXUpBlock3D",
|
106 |
+
"CogVideoXUpBlock3D",
|
107 |
+
"CogVideoXUpBlock3D",
|
108 |
+
),
|
109 |
+
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
110 |
+
latent_channels: int = 16,
|
111 |
+
layers_per_block: int = 3,
|
112 |
+
act_fn: str = "silu",
|
113 |
+
norm_eps: float = 1e-6,
|
114 |
+
norm_num_groups: int = 32,
|
115 |
+
temporal_compression_ratio: float = 4,
|
116 |
+
use_quant_conv: bool = False,
|
117 |
+
use_post_quant_conv: bool = False,
|
118 |
+
|
119 |
+
mini_batch_encoder=4,
|
120 |
+
mini_batch_decoder=1,
|
121 |
+
|
122 |
+
image_key="image",
|
123 |
+
train_decoder_only=False,
|
124 |
+
train_encoder_only=False,
|
125 |
+
monitor=None,
|
126 |
+
ckpt_path=None,
|
127 |
+
lossconfig=None,
|
128 |
+
):
|
129 |
+
super().__init__()
|
130 |
+
self.image_key = image_key
|
131 |
+
down_block_types = str_eval(down_block_types)
|
132 |
+
up_block_types = str_eval(up_block_types)
|
133 |
+
|
134 |
+
self.encoder = CogVideoXEncoder3D(
|
135 |
+
in_channels=in_channels,
|
136 |
+
out_channels=latent_channels,
|
137 |
+
down_block_types=down_block_types,
|
138 |
+
block_out_channels=block_out_channels,
|
139 |
+
layers_per_block=layers_per_block,
|
140 |
+
act_fn=act_fn,
|
141 |
+
norm_eps=norm_eps,
|
142 |
+
norm_num_groups=norm_num_groups,
|
143 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
144 |
+
)
|
145 |
+
|
146 |
+
self.decoder = CogVideoXDecoder3D(
|
147 |
+
in_channels=latent_channels,
|
148 |
+
out_channels=out_channels,
|
149 |
+
up_block_types=up_block_types,
|
150 |
+
block_out_channels=block_out_channels,
|
151 |
+
layers_per_block=layers_per_block,
|
152 |
+
act_fn=act_fn,
|
153 |
+
norm_eps=norm_eps,
|
154 |
+
norm_num_groups=norm_num_groups,
|
155 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
156 |
+
)
|
157 |
+
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
|
158 |
+
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
|
159 |
+
|
160 |
+
self.mini_batch_encoder = mini_batch_encoder
|
161 |
+
self.mini_batch_decoder = mini_batch_decoder
|
162 |
+
self.train_decoder_only = train_decoder_only
|
163 |
+
self.train_encoder_only = train_encoder_only
|
164 |
+
if train_decoder_only:
|
165 |
+
self.encoder.requires_grad_(False)
|
166 |
+
if self.quant_conv is not None:
|
167 |
+
self.quant_conv.requires_grad_(False)
|
168 |
+
if train_encoder_only:
|
169 |
+
self.decoder.requires_grad_(False)
|
170 |
+
if self.post_quant_conv is not None:
|
171 |
+
self.post_quant_conv.requires_grad_(False)
|
172 |
+
if monitor is not None:
|
173 |
+
self.monitor = monitor
|
174 |
+
if ckpt_path is not None:
|
175 |
+
self.init_from_ckpt(ckpt_path, ignore_keys="loss")
|
176 |
+
if lossconfig is not None:
|
177 |
+
self.loss = instantiate_from_config(lossconfig)
|
178 |
+
|
179 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
180 |
+
if path.endswith("safetensors"):
|
181 |
+
from safetensors.torch import load_file, safe_open
|
182 |
+
sd = load_file(path)
|
183 |
+
else:
|
184 |
+
sd = torch.load(path, map_location="cpu")
|
185 |
+
if "state_dict" in list(sd.keys()):
|
186 |
+
sd = sd["state_dict"]
|
187 |
+
keys = list(sd.keys())
|
188 |
+
for k in keys:
|
189 |
+
for ik in ignore_keys:
|
190 |
+
if k.startswith(ik):
|
191 |
+
print("Deleting key {} from state_dict.".format(k))
|
192 |
+
del sd[k]
|
193 |
+
m, u = self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
|
194 |
+
print(f"Restored from {path}")
|
195 |
+
print(f"missing keys: {str(m)}, unexpected keys: {str(u)}")
|
196 |
+
|
197 |
+
def encode(self, x: torch.Tensor) -> EncoderOutput:
|
198 |
+
h = self.encoder(x)
|
199 |
+
self.encoder._clear_fake_context_parallel_cache()
|
200 |
+
|
201 |
+
if self.quant_conv is not None:
|
202 |
+
moments: torch.Tensor = self.quant_conv(h)
|
203 |
+
else:
|
204 |
+
moments: torch.Tensor = h
|
205 |
+
mean, logvar = moments.chunk(2, dim=1)
|
206 |
+
posterior = DiagonalGaussianDistribution(mean, logvar)
|
207 |
+
|
208 |
+
return posterior
|
209 |
+
|
210 |
+
def decode(self, z: torch.Tensor) -> DecoderOutput:
|
211 |
+
if self.post_quant_conv is not None:
|
212 |
+
z = self.post_quant_conv(z)
|
213 |
+
decoded = self.decoder(z)
|
214 |
+
self.decoder._clear_fake_context_parallel_cache()
|
215 |
+
return decoded
|
216 |
+
|
217 |
+
def forward(self, input, sample_posterior=True):
|
218 |
+
if input.ndim==4:
|
219 |
+
input = input.unsqueeze(2)
|
220 |
+
posterior = self.encode(input)
|
221 |
+
if sample_posterior:
|
222 |
+
z = posterior.sample()
|
223 |
+
else:
|
224 |
+
z = posterior.mode()
|
225 |
+
# print("stt latent shape", z.shape)
|
226 |
+
dec = self.decode(z)
|
227 |
+
return dec, posterior
|
228 |
+
|
229 |
+
def get_input(self, batch, k):
|
230 |
+
x = batch[k]
|
231 |
+
if x.ndim==5:
|
232 |
+
x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
|
233 |
+
return x
|
234 |
+
if len(x.shape) == 3:
|
235 |
+
x = x[..., None]
|
236 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
237 |
+
return x
|
238 |
+
|
239 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
240 |
+
inputs = self.get_input(batch, self.image_key)
|
241 |
+
reconstructions, posterior = self(inputs)
|
242 |
+
|
243 |
+
if optimizer_idx == 0:
|
244 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
245 |
+
last_layer=self.get_last_layer(), split="train")
|
246 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
247 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
248 |
+
return aeloss
|
249 |
+
|
250 |
+
if optimizer_idx == 1:
|
251 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
252 |
+
last_layer=self.get_last_layer(), split="train")
|
253 |
+
|
254 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
255 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
256 |
+
return discloss
|
257 |
+
|
258 |
+
def validation_step(self, batch, batch_idx):
|
259 |
+
with torch.no_grad():
|
260 |
+
inputs = self.get_input(batch, self.image_key)
|
261 |
+
reconstructions, posterior = self(inputs)
|
262 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
263 |
+
last_layer=self.get_last_layer(), split="val")
|
264 |
+
|
265 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
266 |
+
last_layer=self.get_last_layer(), split="val")
|
267 |
+
|
268 |
+
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
269 |
+
self.log_dict(log_dict_ae)
|
270 |
+
self.log_dict(log_dict_disc)
|
271 |
+
return self.log_dict
|
272 |
+
|
273 |
+
def configure_optimizers(self):
|
274 |
+
lr = self.learning_rate
|
275 |
+
if self.train_decoder_only:
|
276 |
+
if self.post_quant_conv is not None:
|
277 |
+
training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
|
278 |
+
else:
|
279 |
+
training_list = list(self.decoder.parameters())
|
280 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
281 |
+
elif self.train_encoder_only:
|
282 |
+
if self.quant_conv is not None:
|
283 |
+
training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
|
284 |
+
else:
|
285 |
+
training_list = list(self.encoder.parameters())
|
286 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
287 |
+
else:
|
288 |
+
training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
289 |
+
if self.quant_conv is not None:
|
290 |
+
training_list = training_list + list(self.quant_conv.parameters())
|
291 |
+
if self.post_quant_conv is not None:
|
292 |
+
training_list = training_list + list(self.post_quant_conv.parameters())
|
293 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
294 |
+
opt_disc = torch.optim.Adam(
|
295 |
+
list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
|
296 |
+
lr=lr, betas=(0.5, 0.9)
|
297 |
+
)
|
298 |
+
return [opt_ae, opt_disc], []
|
299 |
+
|
300 |
+
def get_last_layer(self):
|
301 |
+
return self.decoder.conv_out.conv.weight
|
302 |
+
|
303 |
+
@torch.no_grad()
|
304 |
+
def log_images(self, batch, only_inputs=False, **kwargs):
|
305 |
+
log = dict()
|
306 |
+
x = self.get_input(batch, self.image_key)
|
307 |
+
x = x.to(self.device)
|
308 |
+
if not only_inputs:
|
309 |
+
xrec, posterior = self(x)
|
310 |
+
if x.shape[1] > 3:
|
311 |
+
# colorize with random projection
|
312 |
+
assert xrec.shape[1] > 3
|
313 |
+
x = self.to_rgb(x)
|
314 |
+
xrec = self.to_rgb(xrec)
|
315 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
316 |
+
log["reconstructions"] = xrec
|
317 |
+
log["inputs"] = x
|
318 |
+
return log
|
319 |
+
|
320 |
+
def to_rgb(self, x):
|
321 |
+
assert self.image_key == "segmentation"
|
322 |
+
if not hasattr(self, "colorize"):
|
323 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
324 |
+
x = F.conv2d(x, weight=self.colorize)
|
325 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
326 |
+
return x
|
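The kl() method in the class above is the closed-form KL divergence between the diagonal Gaussian posterior and a standard normal, 0.5 * sum(mean^2 + var - 1 - logvar) over the non-batch dimensions. A small numeric sketch of that same computation (torch only; tensor sizes are arbitrary, chosen only for illustration):

import torch

mean = torch.randn(2, 16, 4, 8, 8)          # (B, C_latent, T', H', W'), arbitrary sizes
logvar = torch.randn_like(mean).clamp(-30.0, 20.0)
var = logvar.exp()

# KL(N(mean, var) || N(0, I)), summed over all non-batch dims as in kl() above
dims = list(range(1, mean.ndim))
kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=dims)
print(kl.shape)                              # torch.Size([2]) -- one value per batch element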
easyanimate/vae/ldm/models/cogvideox_enc_dec.py
ADDED
@@ -0,0 +1,312 @@
1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from diffusers.models.autoencoders.autoencoder_kl_cogvideox import (
|
21 |
+
CogVideoXCausalConv3d, CogVideoXDownBlock3D, CogVideoXMidBlock3D,
|
22 |
+
CogVideoXSafeConv3d, CogVideoXSpatialNorm3D, CogVideoXUpBlock3D)
|
23 |
+
from diffusers.utils import logging
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26 |
+
|
27 |
+
|
28 |
+
class CogVideoXEncoder3D(nn.Module):
|
29 |
+
r"""
|
30 |
+
The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
in_channels (`int`, *optional*, defaults to 3):
|
34 |
+
The number of input channels.
|
35 |
+
out_channels (`int`, *optional*, defaults to 3):
|
36 |
+
The number of output channels.
|
37 |
+
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
38 |
+
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
39 |
+
options.
|
40 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
41 |
+
The number of output channels for each block.
|
42 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
43 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
44 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
45 |
+
The number of layers per block.
|
46 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
47 |
+
The number of groups for normalization.
|
48 |
+
"""
|
49 |
+
|
50 |
+
_supports_gradient_checkpointing = True
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_channels: int = 3,
|
55 |
+
out_channels: int = 16,
|
56 |
+
down_block_types: Tuple[str, ...] = (
|
57 |
+
"CogVideoXDownBlock3D",
|
58 |
+
"CogVideoXDownBlock3D",
|
59 |
+
"CogVideoXDownBlock3D",
|
60 |
+
"CogVideoXDownBlock3D",
|
61 |
+
),
|
62 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
63 |
+
layers_per_block: int = 3,
|
64 |
+
act_fn: str = "silu",
|
65 |
+
norm_eps: float = 1e-6,
|
66 |
+
norm_num_groups: int = 32,
|
67 |
+
dropout: float = 0.0,
|
68 |
+
pad_mode: str = "first",
|
69 |
+
temporal_compression_ratio: float = 4,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
# log2 of temporal_compress_times
|
74 |
+
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
75 |
+
|
76 |
+
self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
77 |
+
self.down_blocks = nn.ModuleList([])
|
78 |
+
|
79 |
+
# down blocks
|
80 |
+
output_channel = block_out_channels[0]
|
81 |
+
for i, down_block_type in enumerate(down_block_types):
|
82 |
+
input_channel = output_channel
|
83 |
+
output_channel = block_out_channels[i]
|
84 |
+
is_final_block = i == len(block_out_channels) - 1
|
85 |
+
compress_time = i < temporal_compress_level
|
86 |
+
|
87 |
+
if down_block_type == "CogVideoXDownBlock3D":
|
88 |
+
down_block = CogVideoXDownBlock3D(
|
89 |
+
in_channels=input_channel,
|
90 |
+
out_channels=output_channel,
|
91 |
+
temb_channels=0,
|
92 |
+
dropout=dropout,
|
93 |
+
num_layers=layers_per_block,
|
94 |
+
resnet_eps=norm_eps,
|
95 |
+
resnet_act_fn=act_fn,
|
96 |
+
resnet_groups=norm_num_groups,
|
97 |
+
add_downsample=not is_final_block,
|
98 |
+
compress_time=compress_time,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
|
102 |
+
|
103 |
+
self.down_blocks.append(down_block)
|
104 |
+
|
105 |
+
# mid block
|
106 |
+
self.mid_block = CogVideoXMidBlock3D(
|
107 |
+
in_channels=block_out_channels[-1],
|
108 |
+
temb_channels=0,
|
109 |
+
dropout=dropout,
|
110 |
+
num_layers=2,
|
111 |
+
resnet_eps=norm_eps,
|
112 |
+
resnet_act_fn=act_fn,
|
113 |
+
resnet_groups=norm_num_groups,
|
114 |
+
pad_mode=pad_mode,
|
115 |
+
)
|
116 |
+
|
117 |
+
self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
|
118 |
+
self.conv_act = nn.SiLU()
|
119 |
+
self.conv_out = CogVideoXCausalConv3d(
|
120 |
+
block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
|
121 |
+
)
|
122 |
+
|
123 |
+
self.gradient_checkpointing = False
|
124 |
+
|
125 |
+
def _clear_fake_context_parallel_cache(self):
|
126 |
+
for name, module in self.named_modules():
|
127 |
+
if isinstance(module, CogVideoXCausalConv3d):
|
128 |
+
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
129 |
+
module._clear_fake_context_parallel_cache()
|
130 |
+
|
131 |
+
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
132 |
+
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
133 |
+
hidden_states = self.conv_in(sample)
|
134 |
+
|
135 |
+
if self.training and self.gradient_checkpointing:
|
136 |
+
|
137 |
+
def create_custom_forward(module):
|
138 |
+
def custom_forward(*inputs):
|
139 |
+
return module(*inputs)
|
140 |
+
|
141 |
+
return custom_forward
|
142 |
+
|
143 |
+
# 1. Down
|
144 |
+
for down_block in self.down_blocks:
|
145 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
146 |
+
create_custom_forward(down_block), hidden_states, temb, None
|
147 |
+
)
|
148 |
+
|
149 |
+
# 2. Mid
|
150 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
151 |
+
create_custom_forward(self.mid_block), hidden_states, temb, None
|
152 |
+
)
|
153 |
+
else:
|
154 |
+
# 1. Down
|
155 |
+
for down_block in self.down_blocks:
|
156 |
+
hidden_states = down_block(hidden_states, temb, None)
|
157 |
+
|
158 |
+
# 2. Mid
|
159 |
+
hidden_states = self.mid_block(hidden_states, temb, None)
|
160 |
+
|
161 |
+
# 3. Post-process
|
162 |
+
hidden_states = self.norm_out(hidden_states)
|
163 |
+
hidden_states = self.conv_act(hidden_states)
|
164 |
+
hidden_states = self.conv_out(hidden_states)
|
165 |
+
return hidden_states
|
166 |
+
|
167 |
+
|
168 |
+
class CogVideoXDecoder3D(nn.Module):
|
169 |
+
r"""
|
170 |
+
The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
|
171 |
+
sample.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
in_channels (`int`, *optional*, defaults to 3):
|
175 |
+
The number of input channels.
|
176 |
+
out_channels (`int`, *optional*, defaults to 3):
|
177 |
+
The number of output channels.
|
178 |
+
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
179 |
+
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
180 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
181 |
+
The number of output channels for each block.
|
182 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
183 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
184 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
185 |
+
The number of layers per block.
|
186 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
187 |
+
The number of groups for normalization.
|
188 |
+
"""
|
189 |
+
|
190 |
+
_supports_gradient_checkpointing = True
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
in_channels: int = 16,
|
195 |
+
out_channels: int = 3,
|
196 |
+
up_block_types: Tuple[str, ...] = (
|
197 |
+
"CogVideoXUpBlock3D",
|
198 |
+
"CogVideoXUpBlock3D",
|
199 |
+
"CogVideoXUpBlock3D",
|
200 |
+
"CogVideoXUpBlock3D",
|
201 |
+
),
|
202 |
+
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
203 |
+
layers_per_block: int = 3,
|
204 |
+
act_fn: str = "silu",
|
205 |
+
norm_eps: float = 1e-6,
|
206 |
+
norm_num_groups: int = 32,
|
207 |
+
dropout: float = 0.0,
|
208 |
+
pad_mode: str = "first",
|
209 |
+
temporal_compression_ratio: float = 4,
|
210 |
+
):
|
211 |
+
super().__init__()
|
212 |
+
|
213 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
214 |
+
|
215 |
+
self.conv_in = CogVideoXCausalConv3d(
|
216 |
+
in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
|
217 |
+
)
|
218 |
+
|
219 |
+
# mid block
|
220 |
+
self.mid_block = CogVideoXMidBlock3D(
|
221 |
+
in_channels=reversed_block_out_channels[0],
|
222 |
+
temb_channels=0,
|
223 |
+
num_layers=2,
|
224 |
+
resnet_eps=norm_eps,
|
225 |
+
resnet_act_fn=act_fn,
|
226 |
+
resnet_groups=norm_num_groups,
|
227 |
+
spatial_norm_dim=in_channels,
|
228 |
+
pad_mode=pad_mode,
|
229 |
+
)
|
230 |
+
|
231 |
+
# up blocks
|
232 |
+
self.up_blocks = nn.ModuleList([])
|
233 |
+
|
234 |
+
output_channel = reversed_block_out_channels[0]
|
235 |
+
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
236 |
+
|
237 |
+
for i, up_block_type in enumerate(up_block_types):
|
238 |
+
prev_output_channel = output_channel
|
239 |
+
output_channel = reversed_block_out_channels[i]
|
240 |
+
is_final_block = i == len(block_out_channels) - 1
|
241 |
+
compress_time = i < temporal_compress_level
|
242 |
+
|
243 |
+
if up_block_type == "CogVideoXUpBlock3D":
|
244 |
+
up_block = CogVideoXUpBlock3D(
|
245 |
+
in_channels=prev_output_channel,
|
246 |
+
out_channels=output_channel,
|
247 |
+
temb_channels=0,
|
248 |
+
dropout=dropout,
|
249 |
+
num_layers=layers_per_block + 1,
|
250 |
+
resnet_eps=norm_eps,
|
251 |
+
resnet_act_fn=act_fn,
|
252 |
+
resnet_groups=norm_num_groups,
|
253 |
+
spatial_norm_dim=in_channels,
|
254 |
+
add_upsample=not is_final_block,
|
255 |
+
compress_time=compress_time,
|
256 |
+
pad_mode=pad_mode,
|
257 |
+
)
|
258 |
+
prev_output_channel = output_channel
|
259 |
+
else:
|
260 |
+
raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
|
261 |
+
|
262 |
+
self.up_blocks.append(up_block)
|
263 |
+
|
264 |
+
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
|
265 |
+
self.conv_act = nn.SiLU()
|
266 |
+
self.conv_out = CogVideoXCausalConv3d(
|
267 |
+
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
|
268 |
+
)
|
269 |
+
|
270 |
+
self.gradient_checkpointing = False
|
271 |
+
|
272 |
+
def _clear_fake_context_parallel_cache(self):
|
273 |
+
for name, module in self.named_modules():
|
274 |
+
if isinstance(module, CogVideoXCausalConv3d):
|
275 |
+
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
276 |
+
module._clear_fake_context_parallel_cache()
|
277 |
+
|
278 |
+
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
279 |
+
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
280 |
+
hidden_states = self.conv_in(sample)
|
281 |
+
|
282 |
+
if self.training and self.gradient_checkpointing:
|
283 |
+
|
284 |
+
def create_custom_forward(module):
|
285 |
+
def custom_forward(*inputs):
|
286 |
+
return module(*inputs)
|
287 |
+
|
288 |
+
return custom_forward
|
289 |
+
|
290 |
+
# 1. Mid
|
291 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
292 |
+
create_custom_forward(self.mid_block), hidden_states, temb, sample
|
293 |
+
)
|
294 |
+
|
295 |
+
# 2. Up
|
296 |
+
for up_block in self.up_blocks:
|
297 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
298 |
+
create_custom_forward(up_block), hidden_states, temb, sample
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
# 1. Mid
|
302 |
+
hidden_states = self.mid_block(hidden_states, temb, sample)
|
303 |
+
|
304 |
+
# 2. Up
|
305 |
+
for up_block in self.up_blocks:
|
306 |
+
hidden_states = up_block(hidden_states, temb, sample)
|
307 |
+
|
308 |
+
# 3. Post-process
|
309 |
+
hidden_states = self.norm_out(hidden_states, sample)
|
310 |
+
hidden_states = self.conv_act(hidden_states)
|
311 |
+
hidden_states = self.conv_out(hidden_states)
|
312 |
+
return hidden_states
|
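Both classes above decide per block whether to compress the temporal axis via temporal_compress_level = int(np.log2(temporal_compression_ratio)); with the default ratio of 4 only the first two of the four blocks touch time, while the final block neither downsamples nor compresses. A quick sketch of that bookkeeping (plain Python/NumPy, reusing the default values shown above):

import numpy as np

block_out_channels = (128, 256, 256, 512)    # defaults from the encoder/decoder above
temporal_compression_ratio = 4
temporal_compress_level = int(np.log2(temporal_compression_ratio))  # -> 2

for i, channels in enumerate(block_out_channels):
    is_final_block = i == len(block_out_channels) - 1
    compress_time = i < temporal_compress_level
    print(f"block {i}: out={channels}, downsample={not is_final_block}, compress_time={compress_time}")
# blocks 0 and 1 compress time; blocks 0-2 downsample spatially; block 3 does neither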
easyanimate/vae/ldm/models/{enc_dec_pytorch.py → enc_dec.py}
RENAMED
File without changes
|
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import itertools
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Optional
|
4 |
|
@@ -112,10 +111,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
112 |
monitor=None,
|
113 |
ckpt_path=None,
|
114 |
lossconfig=None,
|
|
|
115 |
slice_compression_vae=False,
|
|
|
|
|
|
|
116 |
mini_batch_encoder=9,
|
117 |
mini_batch_decoder=3,
|
118 |
train_decoder_only=False,
|
|
|
119 |
):
|
120 |
super().__init__()
|
121 |
self.image_key = image_key
|
@@ -137,7 +141,10 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
137 |
act_fn=act_fn,
|
138 |
num_attention_heads=num_attention_heads,
|
139 |
double_z=True,
|
|
|
140 |
slice_compression_vae=slice_compression_vae,
|
|
|
|
|
141 |
mini_batch_encoder=mini_batch_encoder,
|
142 |
)
|
143 |
|
@@ -156,7 +163,11 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
156 |
norm_num_groups=norm_num_groups,
|
157 |
act_fn=act_fn,
|
158 |
num_attention_heads=num_attention_heads,
|
|
|
159 |
slice_compression_vae=slice_compression_vae,
|
|
|
|
|
|
|
160 |
mini_batch_decoder=mini_batch_decoder,
|
161 |
)
|
162 |
|
@@ -166,9 +177,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
166 |
self.mini_batch_encoder = mini_batch_encoder
|
167 |
self.mini_batch_decoder = mini_batch_decoder
|
168 |
self.train_decoder_only = train_decoder_only
|
|
|
169 |
if train_decoder_only:
|
170 |
self.encoder.requires_grad_(False)
|
171 |
-
self.quant_conv
|
|
|
|
|
|
|
|
|
|
|
172 |
if monitor is not None:
|
173 |
self.monitor = monitor
|
174 |
if ckpt_path is not None:
|
@@ -190,28 +207,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
190 |
if k.startswith(ik):
|
191 |
print("Deleting key {} from state_dict.".format(k))
|
192 |
del sd[k]
|
193 |
-
self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
|
194 |
print(f"Restored from {path}")
|
|
|
195 |
|
196 |
def encode(self, x: torch.Tensor) -> EncoderOutput:
|
197 |
h = self.encoder(x)
|
198 |
|
199 |
-
|
|
|
|
|
|
|
200 |
mean, logvar = moments.chunk(2, dim=1)
|
201 |
posterior = DiagonalGaussianDistribution(mean, logvar)
|
202 |
|
203 |
-
# return EncoderOutput(latent_dist=posterior)
|
204 |
return posterior
|
205 |
|
206 |
def decode(self, z: torch.Tensor) -> DecoderOutput:
|
207 |
-
|
208 |
-
|
209 |
decoded = self.decoder(z)
|
210 |
-
|
211 |
-
# return DecoderOutput(sample=decoded)
|
212 |
return decoded
|
213 |
|
214 |
-
|
215 |
def forward(self, input, sample_posterior=True):
|
216 |
if input.ndim==4:
|
217 |
input = input.unsqueeze(2)
|
@@ -235,30 +252,22 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
235 |
return x
|
236 |
|
237 |
def training_step(self, batch, batch_idx, optimizer_idx):
|
238 |
-
# tic = time.time()
|
239 |
inputs = self.get_input(batch, self.image_key)
|
240 |
-
# print(f"get_input time {time.time() - tic}")
|
241 |
-
# tic = time.time()
|
242 |
reconstructions, posterior = self(inputs)
|
243 |
-
# print(f"model forward time {time.time() - tic}")
|
244 |
|
245 |
if optimizer_idx == 0:
|
246 |
-
# train encoder+decoder+logvar
|
247 |
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
248 |
last_layer=self.get_last_layer(), split="train")
|
249 |
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
250 |
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
251 |
-
# print(f"cal loss time {time.time() - tic}")
|
252 |
return aeloss
|
253 |
|
254 |
if optimizer_idx == 1:
|
255 |
-
# train the discriminator
|
256 |
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
257 |
last_layer=self.get_last_layer(), split="train")
|
258 |
|
259 |
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
260 |
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
261 |
-
# print(f"cal loss time {time.time() - tic}")
|
262 |
return discloss
|
263 |
|
264 |
def validation_step(self, batch, batch_idx):
|
@@ -279,17 +288,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
|
|
279 |
def configure_optimizers(self):
|
280 |
lr = self.learning_rate
|
281 |
if self.train_decoder_only:
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
285 |
else:
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
293 |
return [opt_ae, opt_disc], []
|
294 |
|
295 |
def get_last_layer(self):
|
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from typing import Optional
|
3 |
|
|
|
111 |
monitor=None,
|
112 |
ckpt_path=None,
|
113 |
lossconfig=None,
|
114 |
+
slice_mag_vae=False,
|
115 |
slice_compression_vae=False,
|
116 |
+
cache_compression_vae=False,
|
117 |
+
cache_mag_vae=False,
|
118 |
+
spatial_group_norm=False,
|
119 |
mini_batch_encoder=9,
|
120 |
mini_batch_decoder=3,
|
121 |
train_decoder_only=False,
|
122 |
+
train_encoder_only=False,
|
123 |
):
|
124 |
super().__init__()
|
125 |
self.image_key = image_key
|
|
|
141 |
act_fn=act_fn,
|
142 |
num_attention_heads=num_attention_heads,
|
143 |
double_z=True,
|
144 |
+
slice_mag_vae=slice_mag_vae,
|
145 |
slice_compression_vae=slice_compression_vae,
|
146 |
+
cache_compression_vae=cache_compression_vae,
|
147 |
+
spatial_group_norm=spatial_group_norm,
|
148 |
mini_batch_encoder=mini_batch_encoder,
|
149 |
)
|
150 |
|
|
|
163 |
norm_num_groups=norm_num_groups,
|
164 |
act_fn=act_fn,
|
165 |
num_attention_heads=num_attention_heads,
|
166 |
+
slice_mag_vae=slice_mag_vae,
|
167 |
slice_compression_vae=slice_compression_vae,
|
168 |
+
cache_compression_vae=cache_compression_vae,
|
169 |
+
cache_mag_vae=cache_mag_vae,
|
170 |
+
spatial_group_norm=spatial_group_norm,
|
171 |
mini_batch_decoder=mini_batch_decoder,
|
172 |
)
|
173 |
|
|
|
177 |
self.mini_batch_encoder = mini_batch_encoder
|
178 |
self.mini_batch_decoder = mini_batch_decoder
|
179 |
self.train_decoder_only = train_decoder_only
|
180 |
+
self.train_encoder_only = train_encoder_only
|
181 |
if train_decoder_only:
|
182 |
self.encoder.requires_grad_(False)
|
183 |
+
if self.quant_conv is not None:
|
184 |
+
self.quant_conv.requires_grad_(False)
|
185 |
+
if train_encoder_only:
|
186 |
+
self.decoder.requires_grad_(False)
|
187 |
+
if self.post_quant_conv is not None:
|
188 |
+
self.post_quant_conv.requires_grad_(False)
|
189 |
if monitor is not None:
|
190 |
self.monitor = monitor
|
191 |
if ckpt_path is not None:
|
|
|
207 |
if k.startswith(ik):
|
208 |
print("Deleting key {} from state_dict.".format(k))
|
209 |
del sd[k]
|
210 |
+
m, u = self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
|
211 |
print(f"Restored from {path}")
|
212 |
+
print(f"missing keys: {str(m)}, unexpected keys: {str(u)}")
|
213 |
|
214 |
def encode(self, x: torch.Tensor) -> EncoderOutput:
|
215 |
h = self.encoder(x)
|
216 |
|
217 |
+
if self.quant_conv is not None:
|
218 |
+
moments: torch.Tensor = self.quant_conv(h)
|
219 |
+
else:
|
220 |
+
moments: torch.Tensor = h
|
221 |
mean, logvar = moments.chunk(2, dim=1)
|
222 |
posterior = DiagonalGaussianDistribution(mean, logvar)
|
223 |
|
|
|
224 |
return posterior
|
225 |
|
226 |
def decode(self, z: torch.Tensor) -> DecoderOutput:
|
227 |
+
if self.post_quant_conv is not None:
|
228 |
+
z = self.post_quant_conv(z)
|
229 |
decoded = self.decoder(z)
|
|
|
|
|
230 |
return decoded
|
231 |
|
|
|
232 |
def forward(self, input, sample_posterior=True):
|
233 |
if input.ndim==4:
|
234 |
input = input.unsqueeze(2)
|
|
|
252 |
return x
|
253 |
|
254 |
def training_step(self, batch, batch_idx, optimizer_idx):
|
|
|
255 |
inputs = self.get_input(batch, self.image_key)
|
|
|
|
|
256 |
reconstructions, posterior = self(inputs)
|
|
|
257 |
|
258 |
if optimizer_idx == 0:
|
|
|
259 |
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
260 |
last_layer=self.get_last_layer(), split="train")
|
261 |
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
262 |
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
|
|
263 |
return aeloss
|
264 |
|
265 |
if optimizer_idx == 1:
|
|
|
266 |
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
267 |
last_layer=self.get_last_layer(), split="train")
|
268 |
|
269 |
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
270 |
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
|
|
271 |
return discloss
|
272 |
|
273 |
def validation_step(self, batch, batch_idx):
|
|
|
288 |
def configure_optimizers(self):
|
289 |
lr = self.learning_rate
|
290 |
if self.train_decoder_only:
|
291 |
+
if self.post_quant_conv is not None:
|
292 |
+
training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
|
293 |
+
else:
|
294 |
+
training_list = list(self.decoder.parameters())
|
295 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
296 |
+
elif self.train_encoder_only:
|
297 |
+
if self.quant_conv is not None:
|
298 |
+
training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
|
299 |
+
else:
|
300 |
+
training_list = list(self.encoder.parameters())
|
301 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
302 |
else:
|
303 |
+
training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
304 |
+
if self.quant_conv is not None:
|
305 |
+
training_list = training_list + list(self.quant_conv.parameters())
|
306 |
+
if self.post_quant_conv is not None:
|
307 |
+
training_list = training_list + list(self.post_quant_conv.parameters())
|
308 |
+
opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
|
309 |
+
opt_disc = torch.optim.Adam(
|
310 |
+
list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
|
311 |
+
lr=lr, betas=(0.5, 0.9)
|
312 |
+
)
|
313 |
return [opt_ae, opt_disc], []
|
314 |
|
315 |
def get_last_layer(self):
|
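The configure_optimizers change above keeps the usual two-optimizer split: one Adam over the autoencoder parameters (now optionally restricted to encoder-only or decoder-only) and one over the 2D and 3D discriminators, stepped alternately via optimizer_idx. A minimal sketch of that alternating scheme outside Lightning (pure torch; the tiny modules are placeholders standing in for the real VAE and discriminator, not the repo's classes):

import torch
import torch.nn as nn

# Placeholder modules for the sketch only (assumption, not the actual EasyAnimate VAE).
autoencoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 3, 3, padding=1))
discriminator = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1))

opt_ae = torch.optim.Adam(autoencoder.parameters(), lr=1e-4, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

x = torch.randn(2, 3, 32, 32)
for step in range(2):
    # optimizer_idx == 0: reconstruction + adversarial term for the autoencoder
    recon = autoencoder(x)
    aeloss = (recon - x).pow(2).mean() - discriminator(recon).mean()
    opt_ae.zero_grad(); aeloss.backward(); opt_ae.step()

    # optimizer_idx == 1: discriminator loss on real vs. reconstructed samples
    d_real = discriminator(x).mean()
    d_fake = discriminator(autoencoder(x).detach()).mean()
    discloss = d_fake - d_real
    opt_disc.zero_grad(); discloss.backward(); opt_disc.step()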
easyanimate/vae/ldm/models/omnigen_enc_dec.py
CHANGED
@@ -1,6 +1,10 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
|
|
|
|
|
4 |
from ..modules.vaemodules.activations import get_activation
|
5 |
from ..modules.vaemodules.common import CausalConv3d
|
6 |
from ..modules.vaemodules.down_blocks import get_down_block
|
@@ -8,6 +12,16 @@ from ..modules.vaemodules.mid_blocks import get_mid_block
|
|
8 |
from ..modules.vaemodules.up_blocks import get_up_block
|
9 |
|
10 |
|
|
|
|
|
11 |
class Encoder(nn.Module):
|
12 |
r"""
|
13 |
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
@@ -54,7 +68,11 @@ class Encoder(nn.Module):
|
|
54 |
act_fn: str = "silu",
|
55 |
num_attention_heads: int = 1,
|
56 |
double_z: bool = True,
|
|
|
57 |
slice_compression_vae: bool = False,
|
|
|
|
|
|
|
58 |
mini_batch_encoder: int = 9,
|
59 |
verbose = False,
|
60 |
):
|
@@ -118,9 +136,12 @@ class Encoder(nn.Module):
|
|
118 |
conv_out_channels = 2 * out_channels if double_z else out_channels
|
119 |
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
120 |
|
|
|
121 |
self.slice_compression_vae = slice_compression_vae
|
|
|
|
|
122 |
self.mini_batch_encoder = mini_batch_encoder
|
123 |
-
self.
|
124 |
self.verbose = verbose
|
125 |
|
126 |
def set_padding_one_frame(self):
|
@@ -145,36 +166,142 @@ class Encoder(nn.Module):
|
|
145 |
for name, module in self.named_children():
|
146 |
_set_padding_more_frame(name, module)
|
147 |
|
|
|
|
|
|
|
|
|
|
148 |
def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
|
149 |
# x: (B, C, T, H, W)
|
150 |
-
if self.
|
|
|
|
|
151 |
x = torch.concat([previous_features, x], 2)
|
152 |
-
elif
|
153 |
x = torch.concat([x, after_features], 2)
|
154 |
-
elif
|
155 |
x = torch.concat([previous_features, x, after_features], 2)
|
156 |
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
for down_block in self.down_blocks:
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
x = self.mid_block(x)
|
163 |
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
x = self.conv_act(x)
|
166 |
x = self.conv_out(x)
|
167 |
|
168 |
-
if
|
169 |
x = x[:, :, 1:]
|
170 |
-
elif
|
171 |
x = x[:, :, :2]
|
172 |
-
elif
|
173 |
x = x[:, :, 1:3]
|
174 |
return x
|
175 |
|
176 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
177 |
-
if self.
|
|
|
|
|
|
|
178 |
_, _, f, _, _ = x.size()
|
179 |
if f % 2 != 0:
|
180 |
self.set_padding_one_frame()
|
@@ -188,11 +315,15 @@ class Encoder(nn.Module):
|
|
188 |
new_pixel_values = []
|
189 |
start_index = 0
|
190 |
|
191 |
-
previous_features = None
|
192 |
for i in range(start_index, x.shape[2], self.mini_batch_encoder):
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
196 |
new_pixel_values.append(next_frames)
|
197 |
new_pixel_values = torch.cat(new_pixel_values, dim=2)
|
198 |
else:
|
@@ -242,7 +373,11 @@ class Decoder(nn.Module):
|
|
242 |
norm_num_groups: int = 32,
|
243 |
act_fn: str = "silu",
|
244 |
num_attention_heads: int = 1,
|
|
|
245 |
slice_compression_vae: bool = False,
|
|
|
|
|
|
|
246 |
mini_batch_decoder: int = 3,
|
247 |
verbose = False,
|
248 |
):
|
@@ -309,9 +444,12 @@ class Decoder(nn.Module):
|
|
309 |
|
310 |
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
311 |
|
|
|
312 |
self.slice_compression_vae = slice_compression_vae
|
|
|
|
|
313 |
self.mini_batch_decoder = mini_batch_decoder
|
314 |
-
self.
|
315 |
self.verbose = verbose
|
316 |
|
317 |
def set_padding_one_frame(self):
|
@@ -335,22 +473,90 @@ class Decoder(nn.Module):
|
|
335 |
_set_padding_more_frame(sub_name, sub_mod)
|
336 |
for name, module in self.named_children():
|
337 |
_set_padding_more_frame(name, module)
|
|
|
|
|
|
|
|
338 |
|
339 |
def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
|
340 |
# x: (B, C, T, H, W)
|
341 |
-
if self.
|
|
|
|
|
342 |
b, c, t, h, w = x.size()
|
343 |
x = torch.concat([previous_features, x], 2)
|
344 |
x = self.conv_in(x)
|
345 |
x = self.mid_block(x)
|
346 |
x = x[:, :, -t:]
|
347 |
-
elif
|
348 |
b, c, t, h, w = x.size()
|
349 |
x = torch.concat([x, after_features], 2)
|
350 |
x = self.conv_in(x)
|
351 |
x = self.mid_block(x)
|
352 |
x = x[:, :, :t]
|
353 |
-
elif
|
354 |
_, _, t_1, _, _ = previous_features.size()
|
355 |
_, _, t_2, _, _ = x.size()
|
356 |
x = torch.concat([previous_features, x, after_features], 2)
|
@@ -358,20 +564,76 @@ class Decoder(nn.Module):
|
|
358 |
x = self.mid_block(x)
|
359 |
x = x[:, :, t_1:(t_1 + t_2)]
|
360 |
else:
|
361 |
-
|
362 |
-
|
363 |
-
|
|
|
|
|
364 |
for up_block in self.up_blocks:
|
365 |
-
|
|
|
|
|
366 |
|
367 |
-
x = self.conv_norm_out(x)
|
368 |
x = self.conv_act(x)
|
369 |
x = self.conv_out(x)
|
370 |
|
371 |
return x
|
372 |
|
373 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
374 |
-
if self.
|
|
|
|
|
|
|
|
375 |
_, _, f, _, _ = x.size()
|
376 |
if f % 2 != 0:
|
377 |
self.set_padding_one_frame()
|
@@ -391,6 +653,13 @@ class Decoder(nn.Module):
|
|
391 |
previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
|
392 |
new_pixel_values.append(next_frames)
|
393 |
new_pixel_values = torch.cat(new_pixel_values, dim=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
else:
|
395 |
new_pixel_values = self.single_forward(x, None, None)
|
396 |
return new_pixel_values
|
|
|
1 |
+
from typing import Any, Dict
|
2 |
+
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
+
from diffusers.utils import is_torch_version
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
from ..modules.vaemodules.activations import get_activation
|
9 |
from ..modules.vaemodules.common import CausalConv3d
|
10 |
from ..modules.vaemodules.down_blocks import get_down_block
|
|
|
12 |
from ..modules.vaemodules.up_blocks import get_up_block
|
13 |
|
14 |
|
15 |
+
def create_custom_forward(module, return_dict=None):
|
16 |
+
def custom_forward(*inputs):
|
17 |
+
if return_dict is not None:
|
18 |
+
return module(*inputs, return_dict=return_dict)
|
19 |
+
else:
|
20 |
+
return module(*inputs)
|
21 |
+
|
22 |
+
return custom_forward
|
23 |
+
|
24 |
+
|
25 |
class Encoder(nn.Module):
|
26 |
r"""
|
27 |
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
|
|
68 |
act_fn: str = "silu",
|
69 |
num_attention_heads: int = 1,
|
70 |
double_z: bool = True,
|
71 |
+
slice_mag_vae: bool = False,
|
72 |
slice_compression_vae: bool = False,
|
73 |
+
cache_compression_vae: bool = False,
|
74 |
+
cache_mag_vae: bool = False,
|
75 |
+
spatial_group_norm: bool = False,
|
76 |
mini_batch_encoder: int = 9,
|
77 |
verbose = False,
|
78 |
):
|
|
|
136 |
conv_out_channels = 2 * out_channels if double_z else out_channels
|
137 |
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
138 |
|
139 |
+
self.slice_mag_vae = slice_mag_vae
|
140 |
self.slice_compression_vae = slice_compression_vae
|
141 |
+
self.cache_compression_vae = cache_compression_vae
|
142 |
+
self.cache_mag_vae = cache_mag_vae
|
143 |
self.mini_batch_encoder = mini_batch_encoder
|
144 |
+
self.spatial_group_norm = spatial_group_norm
|
145 |
self.verbose = verbose
|
146 |
|
147 |
def set_padding_one_frame(self):
|
|
|
166 |
for name, module in self.named_children():
|
167 |
_set_padding_more_frame(name, module)
|
168 |
|
169 |
+
def set_magvit_padding_one_frame(self):
|
170 |
+
def _set_magvit_padding_one_frame(name, module):
|
171 |
+
if hasattr(module, 'padding_flag'):
|
172 |
+
if self.verbose:
|
173 |
+
print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
|
174 |
+
module.padding_flag = 3
|
175 |
+
for sub_name, sub_mod in module.named_children():
|
176 |
+
_set_magvit_padding_one_frame(sub_name, sub_mod)
|
177 |
+
for name, module in self.named_children():
|
178 |
+
_set_magvit_padding_one_frame(name, module)
|
179 |
+
|
180 |
+
def set_magvit_padding_more_frame(self):
|
181 |
+
def _set_magvit_padding_more_frame(name, module):
|
182 |
+
if hasattr(module, 'padding_flag'):
|
183 |
+
if self.verbose:
|
184 |
+
print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
|
185 |
+
module.padding_flag = 4
|
186 |
+
for sub_name, sub_mod in module.named_children():
|
187 |
+
_set_magvit_padding_more_frame(sub_name, sub_mod)
|
188 |
+
for name, module in self.named_children():
|
189 |
+
_set_magvit_padding_more_frame(name, module)
|
190 |
+
|
191 |
+
def set_cache_slice_vae_padding_one_frame(self):
|
192 |
+
def _set_cache_slice_vae_padding_one_frame(name, module):
|
193 |
+
if hasattr(module, 'padding_flag'):
|
194 |
+
if self.verbose:
|
195 |
+
print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
|
196 |
+
module.padding_flag = 5
|
197 |
+
for sub_name, sub_mod in module.named_children():
|
198 |
+
_set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)
|
199 |
+
for name, module in self.named_children():
|
200 |
+
_set_cache_slice_vae_padding_one_frame(name, module)
|
201 |
+
|
202 |
+
def set_cache_slice_vae_padding_more_frame(self):
|
203 |
+
def _set_cache_slice_vae_padding_more_frame(name, module):
|
204 |
+
if hasattr(module, 'padding_flag'):
|
205 |
+
if self.verbose:
|
206 |
+
print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
|
207 |
+
module.padding_flag = 6
|
208 |
+
for sub_name, sub_mod in module.named_children():
|
209 |
+
_set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)
|
210 |
+
for name, module in self.named_children():
|
211 |
+
_set_cache_slice_vae_padding_more_frame(name, module)
|
212 |
+
|
213 |
+
def set_3dgroupnorm_for_submodule(self):
|
214 |
+
def _set_3dgroupnorm_for_submodule(name, module):
|
215 |
+
if hasattr(module, 'set_3dgroupnorm'):
|
216 |
+
if self.verbose:
|
217 |
+
print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
|
218 |
+
module.set_3dgroupnorm = True
|
219 |
+
for sub_name, sub_mod in module.named_children():
|
220 |
+
_set_3dgroupnorm_for_submodule(sub_name, sub_mod)
|
221 |
+
for name, module in self.named_children():
|
222 |
+
_set_3dgroupnorm_for_submodule(name, module)
|
223 |
+
|
224 |
def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
|
225 |
# x: (B, C, T, H, W)
|
226 |
+
if self.training:
|
227 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
228 |
+
if previous_features is not None and after_features is None:
|
229 |
x = torch.concat([previous_features, x], 2)
|
230 |
+
elif previous_features is None and after_features is not None:
|
231 |
x = torch.concat([x, after_features], 2)
|
232 |
+
elif previous_features is not None and after_features is not None:
|
233 |
x = torch.concat([previous_features, x, after_features], 2)
|
234 |
|
235 |
+
if self.training:
|
236 |
+
x = torch.utils.checkpoint.checkpoint(
|
237 |
+
create_custom_forward(self.conv_in),
|
238 |
+
x,
|
239 |
+
**ckpt_kwargs,
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
x = self.conv_in(x)
|
243 |
for down_block in self.down_blocks:
|
244 |
+
if self.training:
|
245 |
+
x = torch.utils.checkpoint.checkpoint(
|
246 |
+
create_custom_forward(down_block),
|
247 |
+
x,
|
248 |
+
**ckpt_kwargs,
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
x = down_block(x)
|
252 |
|
253 |
x = self.mid_block(x)
|
254 |
|
255 |
+
if self.spatial_group_norm:
|
256 |
+
batch_size = x.shape[0]
|
257 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
258 |
+
x = self.conv_norm_out(x)
|
259 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
|
260 |
+
else:
|
261 |
+
x = self.conv_norm_out(x)
|
262 |
x = self.conv_act(x)
|
263 |
x = self.conv_out(x)
|
264 |
|
265 |
+
if previous_features is not None and after_features is None:
|
266 |
x = x[:, :, 1:]
|
267 |
+
elif previous_features is None and after_features is not None:
|
268 |
x = x[:, :, :2]
|
269 |
+
         elif previous_features is not None and after_features is not None:
             x = x[:, :, 1:3]
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.spatial_group_norm:
+            self.set_3dgroupnorm_for_submodule()
+
+        if self.cache_mag_vae:
+            self.set_magvit_padding_one_frame()
+            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+            self.set_magvit_padding_more_frame()
+            new_pixel_values = [first_frames]
+            for i in range(1, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.cache_compression_vae:
+            _, _, f, _, _ = x.size()
+            if f % 2 != 0:
+                self.set_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+                self.set_padding_more_frame()
+
+                new_pixel_values = [first_frames]
+                start_index = 1
+            else:
+                self.set_padding_more_frame()
+                new_pixel_values = []
+                start_index = 0
+
+            for i in range(start_index, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_compression_vae:
             _, _, f, _, _ = x.size()
             if f % 2 != 0:
                 self.set_padding_one_frame()
...
                 new_pixel_values = []
                 start_index = 0

             for i in range(start_index, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_mag_vae:
+            _, _, f, _, _ = x.size()
+            new_pixel_values = []
+            for i in range(0, x.shape[2], self.mini_batch_encoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None)
                 new_pixel_values.append(next_frames)
             new_pixel_values = torch.cat(new_pixel_values, dim=2)
         else:
...
         norm_num_groups: int = 32,
         act_fn: str = "silu",
         num_attention_heads: int = 1,
+        slice_mag_vae: bool = False,
         slice_compression_vae: bool = False,
+        cache_compression_vae: bool = False,
+        cache_mag_vae: bool = False,
+        spatial_group_norm: bool = False,
         mini_batch_decoder: int = 3,
         verbose = False,
     ):
...

         self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

+        self.slice_mag_vae = slice_mag_vae
         self.slice_compression_vae = slice_compression_vae
+        self.cache_compression_vae = cache_compression_vae
+        self.cache_mag_vae = cache_mag_vae
         self.mini_batch_decoder = mini_batch_decoder
+        self.spatial_group_norm = spatial_group_norm
         self.verbose = verbose

     def set_padding_one_frame(self):
...
                 _set_padding_more_frame(sub_name, sub_mod)
         for name, module in self.named_children():
             _set_padding_more_frame(name, module)
+
+    def set_magvit_padding_one_frame(self):
+        def _set_magvit_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 3
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_one_frame(name, module)
+
+    def set_magvit_padding_more_frame(self):
+        def _set_magvit_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 4
+            for sub_name, sub_mod in module.named_children():
+                _set_magvit_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_magvit_padding_more_frame(name, module)
+
+    def set_cache_slice_vae_padding_one_frame(self):
+        def _set_cache_slice_vae_padding_one_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 5
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_one_frame(name, module)
+
+    def set_cache_slice_vae_padding_more_frame(self):
+        def _set_cache_slice_vae_padding_more_frame(name, module):
+            if hasattr(module, 'padding_flag'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.padding_flag = 6
+            for sub_name, sub_mod in module.named_children():
+                _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_cache_slice_vae_padding_more_frame(name, module)
+
+    def set_3dgroupnorm_for_submodule(self):
+        def _set_3dgroupnorm_for_submodule(name, module):
+            if hasattr(module, 'set_3dgroupnorm'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.set_3dgroupnorm = True
+            for sub_name, sub_mod in module.named_children():
+                _set_3dgroupnorm_for_submodule(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _set_3dgroupnorm_for_submodule(name, module)
+
+    def clear_cache(self):
+        def _clear_cache(name, module):
+            if hasattr(module, 'prev_features'):
+                if self.verbose:
+                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+                module.prev_features = None
+            for sub_name, sub_mod in module.named_children():
+                _clear_cache(sub_name, sub_mod)
+        for name, module in self.named_children():
+            _clear_cache(name, module)

     def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
+        if self.training:
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+        if previous_features is not None and after_features is None:
             b, c, t, h, w = x.size()
             x = torch.concat([previous_features, x], 2)
             x = self.conv_in(x)
             x = self.mid_block(x)
             x = x[:, :, -t:]
+        elif previous_features is None and after_features is not None:
             b, c, t, h, w = x.size()
             x = torch.concat([x, after_features], 2)
             x = self.conv_in(x)
             x = self.mid_block(x)
             x = x[:, :, :t]
+        elif previous_features is not None and after_features is not None:
             _, _, t_1, _, _ = previous_features.size()
             _, _, t_2, _, _ = x.size()
             x = torch.concat([previous_features, x, after_features], 2)
             x = self.conv_in(x)
             x = self.mid_block(x)
             x = x[:, :, t_1:(t_1 + t_2)]
         else:
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.conv_in),
+                    x,
+                    **ckpt_kwargs,
+                )
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block),
+                    x,
+                    **ckpt_kwargs,
+                )
+            else:
+                x = self.conv_in(x)
+                x = self.mid_block(x)
+
         for up_block in self.up_blocks:
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    x,
+                    **ckpt_kwargs,
+                )
+            else:
+                x = up_block(x)
+
+        if self.spatial_group_norm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.conv_norm_out(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.conv_norm_out(x)

         x = self.conv_act(x)
         x = self.conv_out(x)

         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.spatial_group_norm:
+            self.set_3dgroupnorm_for_submodule()
+
+        if self.cache_mag_vae:
+            self.set_magvit_padding_one_frame()
+            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
+            self.set_magvit_padding_more_frame()
+            new_pixel_values = [first_frames]
+            for i in range(1, x.shape[2], self.mini_batch_decoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.cache_compression_vae:
+            _, _, f, _, _ = x.size()
+            if f == 1:
+                self.set_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, :1, :, :], None, None)
+                new_pixel_values = [first_frames]
+                start_index = 1
+            else:
+                self.set_cache_slice_vae_padding_one_frame()
+                first_frames = self.single_forward(x[:, :, :self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values = [first_frames]
+                start_index = self.mini_batch_decoder
+
+            for i in range(start_index, x.shape[2], self.mini_batch_decoder):
+                self.set_cache_slice_vae_padding_more_frame()
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_compression_vae:
             _, _, f, _, _ = x.size()
             if f % 2 != 0:
                 self.set_padding_one_frame()
...
                 previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
                 new_pixel_values.append(next_frames)
             new_pixel_values = torch.cat(new_pixel_values, dim=2)
+        elif self.slice_mag_vae:
+            _, _, f, _, _ = x.size()
+            new_pixel_values = []
+            for i in range(0, x.shape[2], self.mini_batch_decoder):
+                next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None)
+                new_pixel_values.append(next_frames)
+            new_pixel_values = torch.cat(new_pixel_values, dim=2)
         else:
             new_pixel_values = self.single_forward(x, None, None)
         return new_pixel_values
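The new `cache_mag_vae` / `cache_compression_vae` / `slice_*` branches above all share one tiling pattern: run the first frame (or first chunk) through `single_forward`, then sweep the remaining frames in windows of `mini_batch_encoder` / `mini_batch_decoder` and concatenate the chunk outputs along the time axis. Below is a minimal, self-contained sketch of that loop; the `TinyTemporalTiler` module and its sizes are hypothetical stand-ins, not the real encoder or decoder.

```python
import torch
import torch.nn as nn


class TinyTemporalTiler(nn.Module):
    """Toy stand-in that mimics the chunked temporal forward pattern used above."""

    def __init__(self, channels: int = 8, mini_batch: int = 4):
        super().__init__()
        self.mini_batch = mini_batch
        self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1)

    def single_forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W) -> same shape, processed as one chunk
        return self.conv(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Process frame 0 on its own, then the rest in mini-batches of frames,
        # and stitch everything back together along the temporal dimension.
        outputs = [self.single_forward(x[:, :, 0:1])]
        for i in range(1, x.shape[2], self.mini_batch):
            outputs.append(self.single_forward(x[:, :, i:i + self.mini_batch]))
        return torch.cat(outputs, dim=2)


if __name__ == "__main__":
    video = torch.randn(1, 8, 9, 16, 16)          # (B, C, T, H, W)
    print(TinyTemporalTiler()(video).shape)       # torch.Size([1, 8, 9, 16, 16])
```

The real code additionally flips `padding_flag` on every causal convolution between the first chunk and the following ones so that each chunk is padded with (or cached from) the frames of the previous chunk rather than with replicated frames.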
easyanimate/vae/ldm/modules/ema.py
CHANGED
@@ -1,7 +1,8 @@
 #-*- encoding:utf-8 -*-
 import torch
-from torch import nn
 from pytorch_lightning.callbacks import Callback
+from torch import nn
+

 class LitEma(nn.Module):
     def __init__(self, model, decay=0.9999, use_num_upates=True):
easyanimate/vae/ldm/modules/losses/contperceptual.py
CHANGED
@@ -2,8 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?
+
 from ..vaemodules.discriminator import Discriminator3D

+
 class LPIPSWithDiscriminator(nn.Module):
     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
@@ -62,15 +64,6 @@ class LPIPSWithDiscriminator(nn.Module):

         # get new loss_weight
         loss_weights = 1
-        # b, _ ,f, _, _ = reconstructions.size()
-        # loss_weights = torch.ones([b, f]).view(b, 1, f, 1, 1)
-        # loss_weights[:, :, 0] = 3
-        # for i in range(1, f, 8):
-        #     loss_weights[:, :, i - 1] = 3
-        #     loss_weights[:, :, i] = 3
-        # loss_weights[:, :, -1] = 3
-        # loss_weights = loss_weights.permute(0, 2, 1, 3, 4).flatten(0, 1).to(reconstructions.device)
-
         inputs = inputs.permute(0, 2, 1, 3, 4).flatten(0, 1)
         reconstructions = reconstructions.permute(0, 2, 1, 3, 4).flatten(0, 1)

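The surviving lines of this hunk flatten 5D video tensors into a batch of frames before the per-frame perceptual and discriminator terms are computed. A quick illustration of the `permute(0, 2, 1, 3, 4).flatten(0, 1)` reshape; the tensor sizes here are invented for the example:

```python
import torch

videos = torch.randn(2, 3, 5, 64, 64)        # (B, C, T, H, W)
frames = videos.permute(0, 2, 1, 3, 4)       # (B, T, C, H, W)
frames = frames.flatten(0, 1)                # (B*T, C, H, W)
print(frames.shape)                          # torch.Size([10, 3, 64, 64])
```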
easyanimate/vae/ldm/modules/vaemodules/common.py
CHANGED
@@ -38,7 +38,7 @@ class CausalConv3d(nn.Conv3d):
         assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

         t_ks, h_ks, w_ks = kernel_size
-
+        self.t_stride, h_stride, w_stride = stride
         t_dilation, h_dilation, w_dilation = dilation

         t_pad = (t_ks - 1) * t_dilation
@@ -54,6 +54,7 @@ class CausalConv3d(nn.Conv3d):
         self.temporal_padding = t_pad
         self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
         self.padding_flag = 0
+        self.prev_features = None

         super().__init__(
             in_channels=in_channels,
@@ -67,38 +68,81 @@ class CausalConv3d(nn.Conv3d):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
+        dtype = x.dtype
+        x = x.float()
         if self.padding_flag == 0:
             x = F.pad(
                 x,
                 pad=(0, 0, 0, 0, self.temporal_padding, 0),
                 mode="replicate",     # TODO: check if this is necessary
             )
+            x = x.to(dtype=dtype)
+            return super().forward(x)
+        elif self.padding_flag == 3:
+            x = F.pad(
+                x,
+                pad=(0, 0, 0, 0, self.temporal_padding, 0),
+                mode="replicate",     # TODO: check if this is necessary
+            )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+
+            b, c, f, h, w = x.size()
+            outputs = []
+            i = 0
+            while i + self.temporal_padding + 1 <= f:
+                out = super().forward(x[:, :, i:i + self.temporal_padding + 1])
+                i += self.t_stride
+                outputs.append(out)
+            return torch.concat(outputs, 2)
+        elif self.padding_flag == 4:
+            if self.t_stride == 2:
+                x = torch.concat(
+                    [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2
+                )
+            else:
+                x = torch.concat(
+                    [self.prev_features, x], dim = 2
+                )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+
+            b, c, f, h, w = x.size()
+            outputs = []
+            i = 0
+            while i + self.temporal_padding + 1 <= f:
+                out = super().forward(x[:, :, i:i + self.temporal_padding + 1])
+                i += self.t_stride
+                outputs.append(out)
+            return torch.concat(outputs, 2)
+        elif self.padding_flag == 5:
+            x = F.pad(
+                x,
+                pad=(0, 0, 0, 0, self.temporal_padding, 0),
+                mode="replicate",     # TODO: check if this is necessary
+            )
+            x = x.to(dtype=dtype)
+            self.prev_features = x[:, :, -self.temporal_padding:]
+            return super().forward(x)
+        elif self.padding_flag == 6:
+            if self.t_stride == 2:
+                x = torch.concat(
+                    [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2
+                )
+            else:
+                x = torch.concat(
+                    [self.prev_features, x], dim = 2
+                )
+            self.prev_features = x[:, :, -self.temporal_padding:]
+            x = x.to(dtype=dtype)
+            return super().forward(x)
         else:
             x = F.pad(
                 x,
                 pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
             )
-
-
-    def set_padding_one_frame(self):
-        def _set_padding_one_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 1
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_one_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_one_frame(name, module)
-
-    def set_padding_more_frame(self):
-        def _set_padding_more_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 2
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_more_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_more_frame(name, module)
+            x = x.to(dtype=dtype)
+            return super().forward(x)

 class ResidualBlock2D(nn.Module):
     def __init__(
@@ -142,15 +186,29 @@ class ResidualBlock2D(nn.Module):
         else:
             self.shortcut = nn.Identity()

+        self.set_3dgroupnorm = False
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = self.shortcut(x)

-        x = self.norm1(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm1(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm1(x)
         x = self.nonlinearity(x)

         x = self.conv1(x)

-        x = self.norm2(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm2(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm2(x)
         x = self.nonlinearity(x)

         x = self.dropout(x)
@@ -201,15 +259,29 @@ class ResidualBlock3D(nn.Module):
         else:
             self.shortcut = nn.Identity()

+        self.set_3dgroupnorm = False
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = self.shortcut(x)

-        x = self.norm1(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm1(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm1(x)
         x = self.nonlinearity(x)

         x = self.conv1(x)

-        x = self.norm2(x)
+        if self.set_3dgroupnorm:
+            batch_size = x.shape[0]
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            x = self.norm2(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            x = self.norm2(x)
         x = self.nonlinearity(x)

         x = self.dropout(x)
@@ -238,11 +310,18 @@ class SpatialNorm2D(nn.Module):
         self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
         self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
         self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.set_3dgroupnorm = False

     def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
         f_size = f.shape[-2:]
         zq = F.interpolate(zq, size=f_size, mode="nearest")
-        norm_f = self.norm(f)
+        if self.set_3dgroupnorm:
+            batch_size = f.shape[0]
+            f = rearrange(f, "b c t h w -> (b t) c h w")
+            norm_f = self.norm(f)
+            norm_f = rearrange(norm_f, "(b t) c h w -> b c t h w", b=batch_size)
+        else:
+            norm_f = self.norm(f)
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f

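The `padding_flag == 0` path above is plain causal temporal padding: the input is padded only at the front of the time axis (in `replicate` mode) before the underlying `nn.Conv3d` runs, so no output frame can see future frames. Flags 3 to 6 reuse the same idea while caching the last `temporal_padding` frames in `prev_features`, so that the next chunk can be stitched on without a seam. A stripped-down sketch of the causal part only, assuming stride 1 and no dilation; `MiniCausalConv3d` is illustrative, not the class above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MiniCausalConv3d(nn.Conv3d):
    """Causal along T: pad (kernel_t - 1) frames at the front only."""

    def __init__(self, channels: int, kernel_size: int = 3):
        # Only spatial padding is handled by Conv3d itself; temporal padding is done manually.
        super().__init__(channels, channels, kernel_size,
                         padding=(0, kernel_size // 2, kernel_size // 2))
        self.temporal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W); replicate the first frame backwards in time.
        x = F.pad(x, pad=(0, 0, 0, 0, self.temporal_padding, 0), mode="replicate")
        return super().forward(x)


x = torch.randn(1, 4, 8, 16, 16)
print(MiniCausalConv3d(4)(x).shape)  # torch.Size([1, 4, 8, 16, 16])
```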
easyanimate/vae/ldm/modules/vaemodules/upsamplers.py
CHANGED
@@ -137,6 +137,7 @@ class SpatialTemporalUpsampler3D(Upsampler):
         )

         self.padding_flag = 0
+        self.set_3dgroupnorm = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
@@ -145,32 +146,12 @@ class SpatialTemporalUpsampler3D(Upsampler):
         if self.padding_flag == 0:
             if x.shape[2] > 1:
                 first_frame, x = x[:, :, :1], x[:, :, 1:]
-                x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+                x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
                 x = torch.cat([first_frame, x], dim=2)
-        elif self.padding_flag == 2:
-            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+        elif self.padding_flag == 2 or self.padding_flag == 4 or self.padding_flag == 5 or self.padding_flag == 6:
+            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
         return x

-    def set_padding_one_frame(self):
-        def _set_padding_one_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 1
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_one_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_one_frame(name, module)
-
-    def set_padding_more_frame(self):
-        def _set_padding_more_frame(name, module):
-            if hasattr(module, 'padding_flag'):
-                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
-                module.padding_flag = 2
-            for sub_name, sub_mod in module.named_children():
-                _set_padding_more_frame(sub_name, sub_mod)
-        for name, module in self.named_children():
-            _set_padding_more_frame(name, module)
-
 class SpatialTemporalUpsamplerD2S3D(Upsampler):
     def __init__(self, in_channels: int, out_channels: int):
         super().__init__(
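The upsampler keeps its earlier behaviour, spatial ×2 everywhere and temporal ×2 for every frame except the first (so T frames become 2T-1), and simply extends it to the new padding flags, switching to `nearest` temporal interpolation when 3D group norm is emulated frame by frame. The frame bookkeeping, sketched in isolation with made-up sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 8, 8)                                        # (B, C, T, H, W)
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")          # spatial x2 -> (1, 4, 5, 16, 16)
first, rest = x[:, :, :1], x[:, :, 1:]
rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode="trilinear")  # temporal x2 on the tail
x = torch.cat([first, rest], dim=2)
print(x.shape)  # torch.Size([1, 4, 9, 16, 16]), i.e. 2*T - 1 frames
```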
easyanimate/video_caption/README.md
DELETED
@@ -1,90 +0,0 @@
-# Video Caption
-EasyAnimate first uses multi-modal LLMs to generate captions for frames extracted from the video, and then employs LLMs to summarize and refine the generated frame captions into the final video caption. By leveraging [sglang](https://github.com/sgl-project/sglang)/[vLLM](https://github.com/vllm-project/vllm) and [accelerate distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference), the entire process can be very fast.
-
-English | [简体中文](./README_zh-CN.md)
-
-## Quick Start
-1. Cloud usage: AliyunDSW/Docker
-
-    Check [README.md](../../README.md#quick-start) for details.
-
-2. Local usage
-
-    ```shell
-    # Install EasyAnimate requirements first.
-    cd EasyAnimate && pip install -r requirements.txt
-
-    # Install additional requirements for video caption.
-    cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
-
-    # Use DDP instead of DP in EasyOCR detection.
-    site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
-    cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py
-
-    # We strongly recommend using Docker unless you can properly handle the dependency between vllm and torch (CUDA).
-    ```
-
-## Data preprocessing
-Data preprocessing can be divided into three parts:
-
-- Video cut.
-- Video cleaning.
-- Video caption.
-
-The input for data preprocessing can be a video folder or a metadata file (txt/csv/jsonl) containing the video path column. Please check the `get_video_path_list` function in [utils/video_utils.py](utils/video_utils.py) for details.
-
-For easier understanding, we use one video from Panda70m as an example for data preprocessing, [Download here](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4). Please download the video and place it in "datasets/panda_70m/before_vcut/".
-
-```
-📦 datasets/
-├── 📂 panda_70m/
-│   └── 📂 before_vcut/
-│       └── 📄 --C66yU3LjM_2.mp4
-```
-
-1. Video cut
-
-    For long video cut, EasyAnimate utilizes PySceneDetect to identify scene changes within the video and performs scene cutting based on certain threshold values to ensure consistency in the themes of the video segments. After cutting, we only keep segments with lengths ranging from 3 to 10 seconds for model training.
-
-    We have completed the parameters for ```stage_1_video_cut.sh```, so you can run it directly with the command ```sh stage_1_video_cut.sh```. After executing ```stage_1_video_cut.sh```, we obtain short videos in ```easyanimate/video_caption/datasets/panda_70m/train```.
-
-    ```shell
-    sh stage_1_video_cut.sh
-    ```
-2. Video cleaning
-
-    Following SVD's data preparation process, EasyAnimate provides a simple yet effective data processing pipeline for high-quality data filtering and labeling. It also supports distributed processing to accelerate the speed of data preprocessing. The overall process is as follows:
-
-    - Duration filtering: Analyze the basic information of the video to filter out low-quality videos that are short in duration or low in resolution. This filtering result corresponds to the video cut (3s ~ 10s videos).
-    - Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of 4 uniformly sampled frames.
-    - Text filtering: Use easyocr to calculate the text proportion of middle frames to filter out videos with a large proportion of text.
-    - Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly.
-
-    The process file of **Aesthetic filtering** is ```compute_video_frame_quality.py```. After executing ```compute_video_frame_quality.py```, we obtain the file ```datasets/panda_70m/aesthetic_score.jsonl```, where each line corresponds to the aesthetic score of one video.
-
-    The process file of **Text filtering** is ```compute_text_score.py```. After executing ```compute_text_score.py```, we obtain the file ```datasets/panda_70m/text_score.jsonl```, where each line corresponds to the text score of one video.
-
-    The process file of **Motion filtering** is ```compute_motion_score.py```. Motion filtering is based on Aesthetic filtering and Text filtering; only samples that meet certain aesthetic and text score thresholds undergo Motion score calculation. After executing ```compute_motion_score.py```, we obtain the file ```datasets/panda_70m/motion_score.jsonl```, where each line corresponds to the motion score of one video.
-
-    Then we need to filter videos by motion score. After executing ```filter_videos_by_motion_score.py```, we get the file ```datasets/panda_70m/train.jsonl```, which includes the videos we need to caption.
-
-    We have completed the parameters for stage_2_filter_data.sh, so you can run it directly with the command ```sh stage_2_filter_data.sh```.
-
-    ```shell
-    sh stage_2_filter_data.sh
-    ```
-3. Video caption
-
-    Video captioning is carried out in two stages. The first stage involves extracting frames from a video and generating descriptions for them. Subsequently, a large language model is used to summarize these descriptions into a caption.
-
-    We have conducted a detailed, manual comparison of open-source multi-modal LLMs such as [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat), etc. We found that [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) is capable of generating more detailed captions with fewer hallucinations. Additionally, it is supported by serving engines like [sglang](https://github.com/sgl-project/sglang) and [lmdeploy](https://github.com/InternLM/lmdeploy), enabling faster inference.
-
-    Firstly, we use ```caption_video_frame.py``` to generate frame captions. Then, we use ```caption_summary.py``` to generate summary captions.
-
-    We have completed the parameters for stage_3_video_caption.sh, so you can run it directly with the command ```sh stage_3_video_caption.sh```. After executing ```stage_3_video_caption.sh```, we obtain the final json ```train_panda_70m.json``` for EasyAnimate training.
-
-    ```shell
-    sh stage_3_video_caption.sh
-    ```
-
-    If you cannot access Hugging Face, you can run `export HF_ENDPOINT=https://hf-mirror.com` before the above command to download the summary caption model automatically.