Alexander Bagus committed on
Commit d2c9b66 · 1 Parent(s): df103d9

initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +19 -0
  2. README.md +8 -5
  3. app.py +224 -4
  4. examples/depth.jpg +0 -0
  5. examples/hed.jpg +0 -0
  6. examples/pose.jpg +0 -0
  7. examples/pose2.jpg +0 -0
  8. image_utils.py +70 -0
  9. predict_t2i_control.py +228 -0
  10. requirements.txt +7 -0
  11. static/data.json +8 -0
  12. static/footer.html +16 -0
  13. static/header.html +11 -0
  14. videox_fun/__init__.py +0 -0
  15. videox_fun/api/api.py +226 -0
  16. videox_fun/api/api_multi_nodes.py +320 -0
  17. videox_fun/data/__init__.py +9 -0
  18. videox_fun/data/bucket_sampler.py +379 -0
  19. videox_fun/data/dataset_image.py +191 -0
  20. videox_fun/data/dataset_image_video.py +657 -0
  21. videox_fun/data/dataset_video.py +901 -0
  22. videox_fun/data/utils.py +347 -0
  23. videox_fun/pipeline/__init__.py +62 -0
  24. videox_fun/pipeline/pipeline_cogvideox_fun.py +862 -0
  25. videox_fun/pipeline/pipeline_cogvideox_fun_control.py +956 -0
  26. videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py +1136 -0
  27. videox_fun/pipeline/pipeline_fantasy_talking.py +754 -0
  28. videox_fun/pipeline/pipeline_flux.py +978 -0
  29. videox_fun/pipeline/pipeline_flux2.py +900 -0
  30. videox_fun/pipeline/pipeline_flux2_control.py +973 -0
  31. videox_fun/pipeline/pipeline_hunyuanvideo.py +805 -0
  32. videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py +972 -0
  33. videox_fun/pipeline/pipeline_qwenimage.py +767 -0
  34. videox_fun/pipeline/pipeline_qwenimage_edit.py +952 -0
  35. videox_fun/pipeline/pipeline_qwenimage_edit_plus.py +937 -0
  36. videox_fun/pipeline/pipeline_wan.py +576 -0
  37. videox_fun/pipeline/pipeline_wan2_2.py +591 -0
  38. videox_fun/pipeline/pipeline_wan2_2_animate.py +929 -0
  39. videox_fun/pipeline/pipeline_wan2_2_fun_control.py +903 -0
  40. videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py +752 -0
  41. videox_fun/pipeline/pipeline_wan2_2_s2v.py +815 -0
  42. videox_fun/pipeline/pipeline_wan2_2_ti2v.py +732 -0
  43. videox_fun/pipeline/pipeline_wan2_2_vace_fun.py +801 -0
  44. videox_fun/pipeline/pipeline_wan_fun_control.py +799 -0
  45. videox_fun/pipeline/pipeline_wan_fun_inpaint.py +734 -0
  46. videox_fun/pipeline/pipeline_wan_phantom.py +695 -0
  47. videox_fun/pipeline/pipeline_wan_vace.py +787 -0
  48. videox_fun/pipeline/pipeline_z_image.py +613 -0
  49. videox_fun/pipeline/pipeline_z_image_control.py +633 -0
  50. videox_fun/reward/MPS/README.md +1 -0
.gitignore ADDED
@@ -0,0 +1,19 @@
+
+
+
+models/
+
+# Packages
+*.egg
+*.egg-info
+dist
+build
+eggs
+parts
+bin
+var
+sdist
+develop-eggs
+.installed.cfg
+lib64
+__pycache__
README.md CHANGED
@@ -1,14 +1,17 @@
 ---
 title: ZIT Controlnet
-emoji: 📊
-colorFrom: yellow
-colorTo: purple
+emoji: 🖼
+colorFrom: purple
+colorTo: red
 sdk: gradio
-sdk_version: 6.0.2
+sdk_version: 5.44.0
 app_file: app.py
 pinned: false
 license: apache-2.0
 short_description: Supports Canny, HED, Depth, Pose and MLSD
+models:
+- Tongyi-MAI/Z-Image-Turbo
+- alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,7 +1,227 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import numpy as np
+import random
+import json
+import spaces
+import torch
+from diffusers import DiffusionPipeline
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+from videox_fun.pipeline import ZImageControlPipeline
+from videox_fun.models import ZImageControlTransformer2DModel
+from transformers import AutoTokenizer, Qwen3ForCausalLM
+from diffusers import AutoencoderKL
+from image_utils import get_image_latent, scale_image
+# from videox_fun.utils.utils import get_image_latent
+
+
+MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1280
+
+MODEL_LOCAL = "models/Z-Image-Turbo/"
+TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
+
+
+weight_dtype = torch.bfloat16
+
+# load transformer
+transformer = ZImageControlTransformer2DModel.from_pretrained(
+    MODEL_LOCAL,
+    subfolder="transformer",
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.bfloat16,
+    transformer_additional_kwargs={
+        "control_layers_places": [0, 5, 10, 15, 20, 25],
+        "control_in_dim": 16
+    },
+).to(torch.bfloat16)
+
+if TRANSFORMER_LOCAL is not None:
+    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
+    if TRANSFORMER_LOCAL.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(TRANSFORMER_LOCAL)
+    else:
+        state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Load MODEL_REPO
+# Get Vae
+vae = AutoencoderKL.from_pretrained(
+    MODEL_LOCAL,
+    subfolder="vae"
+).to(weight_dtype)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_LOCAL, subfolder="tokenizer"
+)
+text_encoder = Qwen3ForCausalLM.from_pretrained(
+    MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype,
+    low_cpu_mem_usage=True,
+)
+scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
+pipe = ZImageControlPipeline(
+    vae=vae,
+    tokenizer=tokenizer,
+    text_encoder=text_encoder,
+    transformer=transformer,
+    scheduler=scheduler,
+)
+pipe.transformer = transformer
+pipe.to("cuda")
+
+# ======== AoTI compilation + FA3 ========
+pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
+spaces.aoti_blocks_load(pipe.transformer.layers,
+                        "zerogpu-aoti/Z-Image", variant="fa3")
+
+
+@spaces.GPU
+def inference(
+    prompt,
+    input_image,
+    image_scale=1.0,
+    control_context_scale=0.75,
+    seed=42,
+    randomize_seed=True,
+    guidance_scale=1.5,
+    num_inference_steps=8,
+    progress=gr.Progress(track_tqdm=True),
+):
+    # process image
+    if input_image is None:
+        print("Error: input_image is empty.")
+        return None
+
+    input_image, width, height = scale_image(input_image, image_scale)
+
+    control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]
+
+    # generation
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    generator = torch.Generator().manual_seed(seed)
+
+    image = pipe(
+        prompt=prompt,
+        height=height,
+        width=width,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        control_image=control_image,
+        num_inference_steps=num_inference_steps,
+        control_context_scale=control_context_scale,
+    ).images[0]
+
+    return image, seed
+
+
+def read_file(path: str) -> str:
+    with open(path, 'r', encoding='utf-8') as f:
+        content = f.read()
+    return content
+
+
+css = """
+#col-container {
+    margin: 0 auto;
+    max-width: 960px;
+}
+"""
+
+with open('static/data.json', 'r') as file:
+    data = json.load(file)
+    examples = data['examples']
+
+with gr.Blocks() as demo:
+    with gr.Column(elem_id="col-container"):
+        with gr.Column():
+            gr.HTML(read_file("static/header.html"))
+        with gr.Row(equal_height=True):
+            with gr.Column():
+                input_image = gr.Image(
+                    height=290, sources=['upload', 'clipboard'],
+                    image_mode='RGB',
+                    # elem_id="image_upload",
+                    type="pil", label="Upload")
+
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    show_label=False,
+                    lines=2,
+                    placeholder="Enter your prompt",
+                    container=False,
+                )
+
+                run_button = gr.Button("Run", variant="primary")
+            with gr.Column():
+                output_image = gr.Image(label="Result", show_label=False)
+
+        with gr.Accordion("Advanced Settings", open=False):
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
+                step=1,
+                value=0,
+            )
+
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+            with gr.Row():
+                image_scale = gr.Slider(
+                    label="Image scale",
+                    minimum=0.5,
+                    maximum=2.0,
+                    step=0.1,
+                    value=1.0,
+                )
+                control_context_scale = gr.Slider(
+                    label="Control context scale",
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.1,
+                    value=0.75,
+                )
+
+            with gr.Row():
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=0.0,
+                    maximum=10.0,
+                    step=0.1,
+                    value=2.5,
+                )
+
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=30,
+                    step=1,
+                    value=8,
+                )
+        gr.Examples(examples=examples, inputs=[input_image, prompt])
+
+        gr.HTML(read_file("static/footer.html"))
+    gr.on(
+        triggers=[run_button.click, prompt.submit],
+        fn=inference,
+        inputs=[
+            prompt,
+            input_image,
+            image_scale,
+            control_context_scale,
+            seed,
+            randomize_seed,
+            guidance_scale,
+            num_inference_steps,
+        ],
+        outputs=[output_image, seed],
+    )
+
+if __name__ == "__main__":
+    demo.launch(mcp_server=True)
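For local testing outside the Space, the checkpoints referenced by MODEL_LOCAL and TRANSFORMER_LOCAL have to exist on disk before app.py is started. A minimal sketch of one way to fetch them with huggingface_hub follows; the exact filename inside the controlnet repo is an assumption, only the target paths are taken from app.py:

# Hypothetical download step; mirrors MODEL_LOCAL / TRANSFORMER_LOCAL in app.py.
from huggingface_hub import hf_hub_download, snapshot_download

snapshot_download("Tongyi-MAI/Z-Image-Turbo", local_dir="models/Z-Image-Turbo")
hf_hub_download(
    "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union",
    filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors",  # assumed filename in the repo
    local_dir="models",
)
# After that, `python app.py` serves the Gradio UI via demo.launch(mcp_server=True).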
examples/depth.jpg ADDED
examples/hed.jpg ADDED
examples/pose.jpg ADDED
examples/pose2.jpg ADDED
image_utils.py ADDED
@@ -0,0 +1,70 @@
+import torch
+from PIL import Image
+import numpy as np
+
+def scale_image(img, scale):
+    w, h = img.size
+    new_w = int(w * scale)
+    new_h = int(h * scale)
+
+    # Adjust to nearest multiple of 32
+    new_w = (new_w // 32) * 32
+    new_h = (new_h // 32) * 32
+
+    return img.resize((new_w, new_h), Image.LANCZOS), new_w, new_h
+
+def padding_image(images, new_width, new_height):
+    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
+
+    aspect_ratio = images.width / images.height
+    if new_width / new_height > 1:
+        if aspect_ratio > new_width / new_height:
+            new_img_width = new_width
+            new_img_height = int(new_img_width / aspect_ratio)
+        else:
+            new_img_height = new_height
+            new_img_width = int(new_img_height * aspect_ratio)
+    else:
+        if aspect_ratio > new_width / new_height:
+            new_img_width = new_width
+            new_img_height = int(new_img_width / aspect_ratio)
+        else:
+            new_img_height = new_height
+            new_img_width = int(new_img_height * aspect_ratio)
+
+    resized_img = images.resize((new_img_width, new_img_height))
+
+    paste_x = (new_width - new_img_width) // 2
+    paste_y = (new_height - new_img_height) // 2
+
+    new_image.paste(resized_img, (paste_x, paste_y))
+
+    return new_image
+
+
+def get_image_latent(ref_image=None, sample_size=None, padding=False):
+    if ref_image is not None:
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+            if padding:
+                ref_image = padding_image(
+                    ref_image, sample_size[1], sample_size[0])
+            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute(
+                [3, 0, 1, 2]).unsqueeze(0) / 255
+        elif isinstance(ref_image, Image.Image):
+            ref_image = ref_image.convert("RGB")
+            if padding:
+                ref_image = padding_image(
+                    ref_image, sample_size[1], sample_size[0])
+            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute(
+                [3, 0, 1, 2]).unsqueeze(0) / 255
+        else:
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute(
+                [3, 0, 1, 2]).unsqueeze(0) / 255
+
+    return ref_image
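For orientation, get_image_latent returns a five-dimensional float tensor scaled to [0, 1], and app.py drops the singleton frame axis with [:, :, 0]. A minimal sketch of the expected shapes, using one of the example images shipped in this commit:

from PIL import Image
from image_utils import get_image_latent, scale_image

img = Image.open("examples/pose.jpg")
img, width, height = scale_image(img, 1.0)                   # sides snapped down to multiples of 32
latent = get_image_latent(img, sample_size=[height, width])
print(latent.shape)                                          # torch.Size([1, 3, 1, height, width])
control_image = latent[:, :, 0]                              # (1, 3, height, width), as consumed by app.py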
predict_t2i_control.py ADDED
@@ -0,0 +1,228 @@
+import os
+import sys
+
+import numpy as np
+import torch
+from diffusers import FlowMatchEulerDiscreteScheduler
+from omegaconf import OmegaConf
+from PIL import Image
+
+current_file_path = os.path.abspath(__file__)
+project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
+for project_root in project_roots:
+    sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+from videox_fun.dist import set_multi_gpus_devices, shard_model
+from videox_fun.models import (AutoencoderKL, AutoTokenizer,
+                               Qwen3ForCausalLM, ZImageControlTransformer2DModel)
+from videox_fun.models.cache_utils import get_teacache_coefficients
+from videox_fun.pipeline import ZImageControlPipeline
+from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
+from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
+                                               convert_weight_dtype_wrapper)
+from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
+from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, get_image,
+                                    get_video_to_video_latent,
+                                    save_videos_grid)
+
+# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
+# model_full_load means that the entire model will be moved to the GPU.
+#
+# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU,
+# and the transformer model has been quantized to float8, which can save more GPU memory.
+#
+# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+#
+# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
+# and the transformer model has been quantized to float8, which can save more GPU memory.
+#
+# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
+# resulting in slower speeds but saving a large amount of GPU memory.
+GPU_memory_mode = "model_cpu_offload"
+# Multi GPUs config
+# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used.
+# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4.
+# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1.
+ulysses_degree = 1
+ring_degree = 1
+# Use FSDP to save more GPU memory in multi gpus.
+fsdp_dit = False
+fsdp_text_encoder = False
+# Compile will give a speedup in fixed resolution and need a little GPU memory.
+# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
+compile_dit = False
+
+# Config and model path
+config_path = "config/z_image/z_image_control.yaml"
+# model path
+model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"
+
+# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++"
+sampler_name = "Flow"
+
+# Load pretrained model if need
+transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
+vae_path = None
+lora_path = None
+
+# Other params
+sample_size = [1728, 992]
+
+# Use torch.float16 if GPU does not support torch.bfloat16.
+# Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16.
+weight_dtype = torch.bfloat16
+control_image = "asset/pose.jpg"
+control_context_scale = 0.75
+
+# Using a longer negative prompt such as "模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。" can improve stability.
+# Adding words such as "安静,固定" (quiet, static) to the negative prompt can increase dynamism.
+prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"
+negative_prompt = " "
+guidance_scale = 0.00
+seed = 43
+num_inference_steps = 9
+lora_weight = 0.55
+save_path = "samples/z-image-t2i-control"
+
+device = set_multi_gpus_devices(ulysses_degree, ring_degree)
+config = OmegaConf.load(config_path)
+
+transformer = ZImageControlTransformer2DModel.from_pretrained(
+    model_name,
+    subfolder="transformer",
+    low_cpu_mem_usage=True,
+    torch_dtype=weight_dtype,
+    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+).to(weight_dtype)
+
+if transformer_path is not None:
+    print(f"From checkpoint: {transformer_path}")
+    if transformer_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(transformer_path)
+    else:
+        state_dict = torch.load(transformer_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get Vae
+vae = AutoencoderKL.from_pretrained(
+    model_name,
+    subfolder="vae"
+).to(weight_dtype)
+
+if vae_path is not None:
+    print(f"From checkpoint: {vae_path}")
+    if vae_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(vae_path)
+    else:
+        state_dict = torch.load(vae_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = vae.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get tokenizer and text_encoder
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name, subfolder="tokenizer"
+)
+text_encoder = Qwen3ForCausalLM.from_pretrained(
+    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
+    low_cpu_mem_usage=True,
+)
+
+# Get Scheduler
+Chosen_Scheduler = scheduler_dict = {
+    "Flow": FlowMatchEulerDiscreteScheduler,
+    "Flow_Unipc": FlowUniPCMultistepScheduler,
+    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
+}[sampler_name]
+scheduler = Chosen_Scheduler.from_pretrained(
+    model_name,
+    subfolder="scheduler"
+)
+
+pipeline = ZImageControlPipeline(
+    vae=vae,
+    tokenizer=tokenizer,
+    text_encoder=text_encoder,
+    transformer=transformer,
+    scheduler=scheduler,
+)
+
+if ulysses_degree > 1 or ring_degree > 1:
+    from functools import partial
+    transformer.enable_multi_gpus_inference()
+    if fsdp_dit:
+        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks))
+        pipeline.transformer = shard_fn(pipeline.transformer)
+        print("Add FSDP DIT")
+    if fsdp_text_encoder:
+        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers, ignored_modules=[text_encoder.language_model.embed_tokens], transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"])
+        text_encoder = shard_fn(text_encoder)
+        print("Add FSDP TEXT ENCODER")
+
+if compile_dit:
+    for i in range(len(pipeline.transformer.transformer_blocks)):
+        pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i])
+    print("Add Compile")
+
+if GPU_memory_mode == "sequential_cpu_offload":
+    pipeline.enable_sequential_cpu_offload(device=device)
+elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
+    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
+    convert_weight_dtype_wrapper(transformer, weight_dtype)
+    pipeline.enable_model_cpu_offload(device=device)
+elif GPU_memory_mode == "model_cpu_offload":
+    pipeline.enable_model_cpu_offload(device=device)
+elif GPU_memory_mode == "model_full_load_and_qfloat8":
+    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
+    convert_weight_dtype_wrapper(transformer, weight_dtype)
+    pipeline.to(device=device)
+else:
+    pipeline.to(device=device)
+
+generator = torch.Generator(device=device).manual_seed(seed)
+
+if lora_path is not None:
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
+
+with torch.no_grad():
+    if control_image is not None:
+        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]
+
+    sample = pipeline(
+        prompt = prompt,
+        negative_prompt = negative_prompt,
+        height = sample_size[0],
+        width = sample_size[1],
+        generator = generator,
+        guidance_scale = guidance_scale,
+        control_image = control_image,
+        num_inference_steps = num_inference_steps,
+        control_context_scale = control_context_scale,
+    ).images
+
+if lora_path is not None:
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
+
+def save_results():
+    if not os.path.exists(save_path):
+        os.makedirs(save_path, exist_ok=True)
+
+    index = len([path for path in os.listdir(save_path)]) + 1
+    prefix = str(index).zfill(8)
+    video_path = os.path.join(save_path, prefix + ".png")
+    image = sample[0]
+    image.save(video_path)
+
+if ulysses_degree * ring_degree > 1:
+    import torch.distributed as dist
+    if dist.get_rank() == 0:
+        save_results()
+else:
+    save_results()
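The script also expects config/z_image/z_image_control.yaml, which is not part of this commit; only its transformer_additional_kwargs section is consumed here. Based on the values that app.py passes inline for the same model, the loaded config plausibly has the shape sketched below; treat the exact contents as an assumption:

# Assumed structure of the config once loaded by OmegaConf, mirroring the kwargs
# that app.py in this commit passes directly to ZImageControlTransformer2DModel.
config = {
    "transformer_additional_kwargs": {
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    }
}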
requirements.txt ADDED
@@ -0,0 +1,7 @@
+gradio
+torch
+transformers
+accelerate
+spaces
+git+https://github.com/huggingface/diffusers.git
+kernels
static/data.json ADDED
@@ -0,0 +1,8 @@
+{
+    "examples": [
+        ["examples/hed.jpg", "A middle-aged man with a short beard, wearing a casual button-down shirt, sitting at a polished dark wooden table, holding a tumbler of whiskey with ice and taking a thoughtful sip. The background is softly lit."],
+        ["examples/depth.jpg", "Modern minimalist, clean lines, open plan, natural light, spacious, serene, contemporary, elegant, architectural, inviting, sophisticated, light-filled, harmonious, texture, shadows, high ceilings."],
+        ["examples/pose.jpg", "A fit, athletic young woman, squatting low, glancing confidently at the camera. She's on a picturesque tropical beach with gentle waves lapping the shore. The image has the crisp, high-contrast look of a fashion magazine cover. Dynamic pose, bright and inviting."],
+        ["examples/pose2.jpg", "A majestic female paladin in gleaming plate armor, standing tall and proud, bathed in a celestial glow, with a determined expression, holding a radiant sword aloft against a backdrop of a sun-drenched, ancient castle."]
+    ]
+}
static/footer.html ADDED
@@ -0,0 +1,16 @@
+<div>
+    I made this space after seeing a Reddit post about using ControlNet editing with Z-Image from Alibaba.
+    The code looks solid and serves as a great example. I believe there's a lot of potential to build on top of this, add new features, and explore even more creative ideas using this technique.
+
+    <h2>Usage</h2>
+    You can change control_context_scale for more control and better detail.
+    For best results, use a detailed prompt.
+    The recommended control_context_scale range is 0.65 to 0.80.
+
+    <h2>Reference</h2>
+    <ul>
+        <li>Tongyi-MAI/Z-Image-Turbo: <a href="https://huggingface.co/Tongyi-MAI/Z-Image-Turbo">https://huggingface.co/Tongyi-MAI/Z-Image-Turbo</a></li>
+        <li>alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union: <a href="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union">https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union</a></li>
+        <li>VideoX-Fun: <a href="https://github.com/aigc-apps/VideoX-Fun">https://github.com/aigc-apps/VideoX-Fun</a></li>
+    </ul>
+</div>
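The usage note above maps directly onto the pipeline call in app.py. A minimal sketch of sweeping control_context_scale over the recommended range; pipe, control_image, height and width are assumed to be set up exactly as in app.py:

# Sketch only: reuses the objects built in app.py rather than constructing new ones.
for scale in (0.65, 0.70, 0.75, 0.80):  # recommended range from the usage note
    image = pipe(
        prompt="A majestic female paladin in gleaming plate armor, holding a radiant sword aloft",
        height=height,
        width=width,
        guidance_scale=1.5,
        num_inference_steps=8,
        control_image=control_image,
        control_context_scale=scale,
    ).images[0]
    image.save(f"paladin_ccs_{scale:.2f}.png")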
static/header.html ADDED
@@ -0,0 +1,11 @@
+<div style="text-align: center; max-width: 600px; margin: 0 auto;">
+    <h1>
+        Z Image Turbo (ZIT) - Controlnet
+    </h1>
+    <div class="grid-container">
+        <p>
+            Supports multiple control conditions - including Canny, HED, Depth, Pose and MLSD.
+            <br>
+            If you like my spaces, please support me by visiting <a href="https://aisudo.com/" target="_blank">AiSudo</a> for more image generation 😊
+    </div>
+</div>
videox_fun/__init__.py ADDED
File without changes
videox_fun/api/api.py ADDED
@@ -0,0 +1,226 @@
+import base64
+import gc
+import hashlib
+import io
+import os
+import tempfile
+from io import BytesIO
+
+import gradio as gr
+import requests
+import torch
+from fastapi import FastAPI
+from PIL import Image
+
+
+# Function to encode a file to Base64
+def encode_file_to_base64(file_path):
+    with open(file_path, "rb") as file:
+        # Encode the data to Base64
+        file_base64 = base64.b64encode(file.read())
+        return file_base64
+
+def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
+    @app.post("/videox_fun/update_diffusion_transformer")
+    def _update_diffusion_transformer_api(
+        datas: dict,
+    ):
+        diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
+
+        try:
+            controller.update_diffusion_transformer(
+                diffusion_transformer_path
+            )
+            comment = "Success"
+        except Exception as e:
+            torch.cuda.empty_cache()
+            comment = f"Error. error information is {str(e)}"
+
+        return {"message": comment}
+
+def download_from_url(url, timeout=10):
+    try:
+        response = requests.get(url, timeout=timeout)
+        response.raise_for_status()  # Check whether the request succeeded
+        return response.content
+    except requests.exceptions.RequestException as e:
+        print(f"Error downloading from {url}: {e}")
+        return None
+
+def save_base64_video(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.mp4"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def save_base64_image(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def save_url_video(url):
+    video_data = download_from_url(url)
+    if video_data:
+        return save_base64_video(base64.b64encode(video_data))
+    return None
+
+def save_url_image(url):
+    image_data = download_from_url(url)
+    if image_data:
+        return save_base64_image(base64.b64encode(image_data))
+    return None
+
+def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
+    @app.post("/videox_fun/infer_forward")
+    def _infer_forward_api(
+        datas: dict,
+    ):
+        base_model_path = datas.get('base_model_path', 'none')
+        base_model_2_path = datas.get('base_model_2_path', 'none')
+        lora_model_path = datas.get('lora_model_path', 'none')
+        lora_model_2_path = datas.get('lora_model_2_path', 'none')
+        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
+        prompt_textbox = datas.get('prompt_textbox', None)
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
+        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
+        sample_step_slider = datas.get('sample_step_slider', 30)
+        resize_method = datas.get('resize_method', "Generate by")
+        width_slider = datas.get('width_slider', 672)
+        height_slider = datas.get('height_slider', 384)
+        base_resolution = datas.get('base_resolution', 512)
+        is_image = datas.get('is_image', False)
+        generation_method = datas.get('generation_method', False)
+        length_slider = datas.get('length_slider', 49)
+        overlap_video_length = datas.get('overlap_video_length', 4)
+        partial_video_length = datas.get('partial_video_length', 72)
+        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+        start_image = datas.get('start_image', None)
+        end_image = datas.get('end_image', None)
+        validation_video = datas.get('validation_video', None)
+        validation_video_mask = datas.get('validation_video_mask', None)
+        control_video = datas.get('control_video', None)
+        denoise_strength = datas.get('denoise_strength', 0.70)
+        seed_textbox = datas.get("seed_textbox", 43)
+
+        ref_image = datas.get('ref_image', None)
+        enable_teacache = datas.get('enable_teacache', True)
+        teacache_threshold = datas.get('teacache_threshold', 0.10)
+        num_skip_start_steps = datas.get('num_skip_start_steps', 1)
+        teacache_offload = datas.get('teacache_offload', False)
+        cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
+        enable_riflex = datas.get('enable_riflex', False)
+        riflex_k = datas.get('riflex_k', 6)
+        fps = datas.get('fps', None)
+
+        generation_method = "Image Generation" if is_image else generation_method
+
+        if start_image is not None:
+            if start_image.startswith('http'):
+                start_image = save_url_image(start_image)
+                start_image = [Image.open(start_image).convert("RGB")]
+            else:
+                start_image = base64.b64decode(start_image)
+                start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
+
+        if end_image is not None:
+            if end_image.startswith('http'):
+                end_image = save_url_image(end_image)
+                end_image = [Image.open(end_image).convert("RGB")]
+            else:
+                end_image = base64.b64decode(end_image)
+                end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
+
+        if validation_video is not None:
+            if validation_video.startswith('http'):
+                validation_video = save_url_video(validation_video)
+            else:
+                validation_video = save_base64_video(validation_video)
+
+        if validation_video_mask is not None:
+            if validation_video_mask.startswith('http'):
+                validation_video_mask = save_url_image(validation_video_mask)
+            else:
+                validation_video_mask = save_base64_image(validation_video_mask)
+
+        if control_video is not None:
+            if control_video.startswith('http'):
+                control_video = save_url_video(control_video)
+            else:
+                control_video = save_base64_video(control_video)
+
+        if ref_image is not None:
+            if ref_image.startswith('http'):
+                ref_image = save_url_image(ref_image)
+                ref_image = [Image.open(ref_image).convert("RGB")]
+            else:
+                ref_image = base64.b64decode(ref_image)
+                ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
+
+        try:
+            save_sample_path, comment = controller.generate(
+                "",
+                base_model_path,
+                lora_model_path,
+                lora_alpha_slider,
+                prompt_textbox,
+                negative_prompt_textbox,
+                sampler_dropdown,
+                sample_step_slider,
+                resize_method,
+                width_slider,
+                height_slider,
+                base_resolution,
+                generation_method,
+                length_slider,
+                overlap_video_length,
+                partial_video_length,
+                cfg_scale_slider,
+                start_image,
+                end_image,
+                validation_video,
+                validation_video_mask,
+                control_video,
+                denoise_strength,
+                seed_textbox,
+                ref_image = ref_image,
+                enable_teacache = enable_teacache,
+                teacache_threshold = teacache_threshold,
+                num_skip_start_steps = num_skip_start_steps,
+                teacache_offload = teacache_offload,
+                cfg_skip_ratio = cfg_skip_ratio,
+                enable_riflex = enable_riflex,
+                riflex_k = riflex_k,
+                base_model_2_dropdown = base_model_2_path,
+                lora_model_2_dropdown = lora_model_2_path,
+                fps = fps,
+                is_api = True,
+            )
+        except Exception as e:
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            save_sample_path = ""
+            comment = f"Error. error information is {str(e)}"
+            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+
+        if save_sample_path != "":
+            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+        else:
+            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
videox_fun/api/api_multi_nodes.py ADDED
@@ -0,0 +1,320 @@
+# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
+import base64
+import gc
+import hashlib
+import io
+import os
+import tempfile
+from io import BytesIO
+
+import gradio as gr
+import requests
+import torch
+import torch.distributed as dist
+from fastapi import FastAPI, HTTPException
+from PIL import Image
+
+from .api import download_from_url, encode_file_to_base64
+
+try:
+    import ray
+except:
+    print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
+    ray = None
+
+def save_base64_video_dist(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.mp4"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            with open(file_path, 'wb') as video_file:
+                video_file.write(video_data)
+        dist.barrier()
+    else:
+        with open(file_path, 'wb') as video_file:
+            video_file.write(video_data)
+    return file_path
+
+def save_base64_image_dist(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"
+
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            with open(file_path, 'wb') as video_file:
+                video_file.write(video_data)
+        dist.barrier()
+    else:
+        with open(file_path, 'wb') as video_file:
+            video_file.write(video_data)
+    return file_path
+
+def save_url_video_dist(url):
+    video_data = download_from_url(url)
+    if video_data:
+        return save_base64_video_dist(base64.b64encode(video_data))
+    return None
+
+def save_url_image_dist(url):
+    image_data = download_from_url(url)
+    if image_data:
+        return save_base64_image_dist(base64.b64encode(image_data))
+    return None
+
+if ray is not None:
+    @ray.remote(num_gpus=1)
+    class MultiNodesGenerator:
+        def __init__(
+            self, rank: int, world_size: int, Controller,
+            GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
+            config_path=None, ulysses_degree=1, ring_degree=1,
+            fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
+            weight_dtype=None, savedir_sample=None,
+        ):
+            # Set PyTorch distributed environment variables
+            os.environ["RANK"] = str(rank)
+            os.environ["WORLD_SIZE"] = str(world_size)
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
+            os.environ["MASTER_PORT"] = "29500"
+
+            self.rank = rank
+            self.controller = Controller(
+                GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
+                ulysses_degree=ulysses_degree, ring_degree=ring_degree,
+                fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
+                weight_dtype=weight_dtype, savedir_sample=savedir_sample,
+            )
+
+        def generate(self, datas):
+            try:
+                base_model_path = datas.get('base_model_path', 'none')
+                base_model_2_path = datas.get('base_model_2_path', 'none')
+                lora_model_path = datas.get('lora_model_path', 'none')
+                lora_model_2_path = datas.get('lora_model_2_path', 'none')
+                lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
+                prompt_textbox = datas.get('prompt_textbox', None)
+                negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
+                sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
+                sample_step_slider = datas.get('sample_step_slider', 30)
+                resize_method = datas.get('resize_method', "Generate by")
+                width_slider = datas.get('width_slider', 672)
+                height_slider = datas.get('height_slider', 384)
+                base_resolution = datas.get('base_resolution', 512)
+                is_image = datas.get('is_image', False)
+                generation_method = datas.get('generation_method', False)
+                length_slider = datas.get('length_slider', 49)
+                overlap_video_length = datas.get('overlap_video_length', 4)
+                partial_video_length = datas.get('partial_video_length', 72)
+                cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+                start_image = datas.get('start_image', None)
+                end_image = datas.get('end_image', None)
+                validation_video = datas.get('validation_video', None)
+                validation_video_mask = datas.get('validation_video_mask', None)
+                control_video = datas.get('control_video', None)
+                denoise_strength = datas.get('denoise_strength', 0.70)
+                seed_textbox = datas.get("seed_textbox", 43)
+
+                ref_image = datas.get('ref_image', None)
+                enable_teacache = datas.get('enable_teacache', True)
+                teacache_threshold = datas.get('teacache_threshold', 0.10)
+                num_skip_start_steps = datas.get('num_skip_start_steps', 1)
+                teacache_offload = datas.get('teacache_offload', False)
+                cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
+                enable_riflex = datas.get('enable_riflex', False)
+                riflex_k = datas.get('riflex_k', 6)
+                fps = datas.get('fps', None)
+
+                generation_method = "Image Generation" if is_image else generation_method
+
+                if start_image is not None:
+                    if start_image.startswith('http'):
+                        start_image = save_url_image_dist(start_image)
+                        start_image = [Image.open(start_image).convert("RGB")]
+                    else:
+                        start_image = base64.b64decode(start_image)
+                        start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
+
+                if end_image is not None:
+                    if end_image.startswith('http'):
+                        end_image = save_url_image_dist(end_image)
+                        end_image = [Image.open(end_image).convert("RGB")]
+                    else:
+                        end_image = base64.b64decode(end_image)
+                        end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
+
+                if validation_video is not None:
+                    if validation_video.startswith('http'):
+                        validation_video = save_url_video_dist(validation_video)
+                    else:
+                        validation_video = save_base64_video_dist(validation_video)
+
+                if validation_video_mask is not None:
+                    if validation_video_mask.startswith('http'):
+                        validation_video_mask = save_url_image_dist(validation_video_mask)
+                    else:
+                        validation_video_mask = save_base64_image_dist(validation_video_mask)
+
+                if control_video is not None:
+                    if control_video.startswith('http'):
+                        control_video = save_url_video_dist(control_video)
+                    else:
+                        control_video = save_base64_video_dist(control_video)
+
+                if ref_image is not None:
+                    if ref_image.startswith('http'):
+                        ref_image = save_url_image_dist(ref_image)
+                        ref_image = [Image.open(ref_image).convert("RGB")]
+                    else:
+                        ref_image = base64.b64decode(ref_image)
+                        ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
+
+                try:
+                    save_sample_path, comment = self.controller.generate(
+                        "",
+                        base_model_path,
+                        lora_model_path,
+                        lora_alpha_slider,
+                        prompt_textbox,
+                        negative_prompt_textbox,
+                        sampler_dropdown,
+                        sample_step_slider,
+                        resize_method,
+                        width_slider,
+                        height_slider,
+                        base_resolution,
+                        generation_method,
+                        length_slider,
+                        overlap_video_length,
+                        partial_video_length,
+                        cfg_scale_slider,
+                        start_image,
+                        end_image,
+                        validation_video,
+                        validation_video_mask,
+                        control_video,
+                        denoise_strength,
+                        seed_textbox,
+                        ref_image = ref_image,
+                        enable_teacache = enable_teacache,
+                        teacache_threshold = teacache_threshold,
+                        num_skip_start_steps = num_skip_start_steps,
+                        teacache_offload = teacache_offload,
+                        cfg_skip_ratio = cfg_skip_ratio,
+                        enable_riflex = enable_riflex,
+                        riflex_k = riflex_k,
+                        base_model_2_dropdown = base_model_2_path,
+                        lora_model_2_dropdown = lora_model_2_path,
+                        fps = fps,
+                        is_api = True,
+                    )
+                except Exception as e:
+                    gc.collect()
+                    torch.cuda.empty_cache()
+                    torch.cuda.ipc_collect()
+                    save_sample_path = ""
+                    comment = f"Error. error information is {str(e)}"
+                    if dist.is_initialized():
+                        if dist.get_rank() == 0:
+                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+                        else:
+                            return None
+                    else:
+                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+
+                if dist.is_initialized():
+                    if dist.get_rank() == 0:
+                        if save_sample_path != "":
+                            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+                        else:
+                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+                    else:
+                        return None
+                else:
+                    if save_sample_path != "":
+                        return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+                    else:
+                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+
+            except Exception as e:
+                print(f"Error generating: {str(e)}")
+                comment = f"Error generating: {str(e)}"
+                if dist.is_initialized():
+                    if dist.get_rank() == 0:
+                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+                    else:
+                        return None
+                else:
+                    return {"message": comment, "save_sample_path": None, "base64_encoding": None}
+
+    class MultiNodesEngine:
+        def __init__(
+            self,
+            world_size,
+            Controller,
+            GPU_memory_mode,
+            scheduler_dict,
+            model_name,
+            model_type,
+            config_path,
+            ulysses_degree=1,
+            ring_degree=1,
+            fsdp_dit=False,
+            fsdp_text_encoder=False,
+            compile_dit=False,
+            weight_dtype=torch.bfloat16,
+            savedir_sample="samples"
+        ):
+            # Ensure Ray is initialized
+            if not ray.is_initialized():
+                ray.init()
+
+            num_workers = world_size
+            self.workers = [
+                MultiNodesGenerator.remote(
+                    rank, world_size, Controller,
+                    GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
+                    ulysses_degree=ulysses_degree, ring_degree=ring_degree,
+                    fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
+                    weight_dtype=weight_dtype, savedir_sample=savedir_sample,
+                )
+                for rank in range(num_workers)
+            ]
+            print("Update workers done")
+
+        async def generate(self, data):
+            results = ray.get([
+                worker.generate.remote(data)
+                for worker in self.workers
+            ])
+
+            return next(path for path in results if path is not None)
+
+    def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
+
+        @app.post("/videox_fun/infer_forward")
+        async def _multi_nodes_infer_forward_api(
+            datas: dict,
+        ):
+            try:
+                result = await engine.generate(datas)
+                return result
+            except Exception as e:
+                if isinstance(e, HTTPException):
+                    raise e
+                raise HTTPException(status_code=500, detail=str(e))
+else:
+    MultiNodesEngine = None
+    MultiNodesGenerator = None
+    multi_nodes_infer_forward_api = None
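Wiring the Ray-backed engine into a server follows the constructor and helper defined above. A minimal sketch; Controller and scheduler_dict come from the rest of VideoX-Fun and are stand-ins here, not part of this file:

# Sketch only: `Controller` and `scheduler_dict` are hypothetical names from the project's UI code.
import torch
from fastapi import FastAPI

engine = MultiNodesEngine(
    world_size=2,
    Controller=Controller,                          # hypothetical controller class
    GPU_memory_mode="model_cpu_offload",
    scheduler_dict=scheduler_dict,                  # hypothetical sampler-name -> scheduler mapping
    model_name="models/Diffusion_Transformer/Z-Image-Turbo/",
    model_type="Inpaint",
    config_path="config/z_image/z_image_control.yaml",
    ulysses_degree=2,
    ring_degree=1,
    weight_dtype=torch.bfloat16,
)
app = FastAPI()
multi_nodes_infer_forward_api(None, app, engine)    # the gr.Blocks argument is unused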
videox_fun/data/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .dataset_image import CC15M, ImageEditDataset
+from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset,
+                                  ImageVideoSampler)
+from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M
+from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
+                    custom_meshgrid, get_random_mask, get_relative_pose,
+                    get_video_reader_batch, padding_image, process_pose_file,
+                    process_pose_params, ray_condition, resize_frame,
+                    resize_image_with_target_area)
videox_fun/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # Get the video dimensions
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # cast the float value to an int
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast the float value to an int
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e, self.dataset[idx], "This item is error, please check it.")
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e, self.dataset[idx], "This item is error, please check it.")
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # Get the video dimensions
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # cast the float value to an int
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast the float value to an int
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e, self.dataset[idx], "This item is error, please check it.")
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
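A minimal usage sketch for the bucket samplers above, assuming the module-level default ASPECT_RATIO_512 is available; the toy dataset and every literal value below are illustrative only, and torch's built-in RandomSampler is used just to keep the sketch self-contained (any Sampler instance works).

from torch.utils.data import DataLoader, Dataset, RandomSampler

class _ToyImageInfoDataset(Dataset):
    # Tiny stand-in dataset: each item already carries width/height, so no file I/O is needed.
    def __init__(self):
        self.items = [
            {"width": 512, "height": 512, "file_path": "a.jpg", "text": "square image"},
            {"width": 912, "height": 512, "file_path": "b.jpg", "text": "wide image"},
        ] * 8
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

toy = _ToyImageInfoDataset()
batch_sampler = AspectRatioBatchImageSampler(
    sampler=RandomSampler(toy), dataset=toy, batch_size=4,
)
# Each yielded batch only mixes indices whose aspect ratios fall into the same bucket.
loader = DataLoader(toy, batch_sampler=batch_sampler, num_workers=0)
for batch in loader:
    print(batch["width"], batch["height"])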
videox_fun/data/dataset_image.py ADDED
@@ -0,0 +1,191 @@
1
+ import csv
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
+
+ from .utils import get_random_mask
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ class ImageEditDataset(Dataset):
69
+ def __init__(
70
+ self,
71
+ ann_path, data_root=None,
72
+ image_sample_size=512,
73
+ text_drop_ratio=0.1,
74
+ enable_bucket=False,
75
+ enable_inpaint=False,
76
+ return_file_name=False,
77
+ ):
78
+ # Loading annotations from files
79
+ print(f"loading annotations from {ann_path} ...")
80
+ if ann_path.endswith('.csv'):
81
+ with open(ann_path, 'r') as csvfile:
82
+ dataset = list(csv.DictReader(csvfile))
83
+ elif ann_path.endswith('.json'):
84
+ dataset = json.load(open(ann_path))
85
+
86
+ self.data_root = data_root
87
+ self.dataset = dataset
88
+
89
+ self.length = len(self.dataset)
90
+ print(f"data scale: {self.length}")
91
+ # TODO: enable bucket training
92
+ self.enable_bucket = enable_bucket
93
+ self.text_drop_ratio = text_drop_ratio
94
+ self.enable_inpaint = enable_inpaint
95
+ self.return_file_name = return_file_name
96
+
97
+ # Image params
98
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
99
+ self.image_transforms = transforms.Compose([
100
+ transforms.Resize(min(self.image_sample_size)),
101
+ transforms.CenterCrop(self.image_sample_size),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
104
+ ])
105
+
106
+ def get_batch(self, idx):
107
+ data_info = self.dataset[idx % len(self.dataset)]
108
+
109
+ image_path, text = data_info['file_path'], data_info['text']
110
+ if self.data_root is not None:
111
+ image_path = os.path.join(self.data_root, image_path)
112
+ image = Image.open(image_path).convert('RGB')
113
+
114
+ if not self.enable_bucket:
115
+ raise ValueError("Not enable_bucket is not supported now. ")
116
+ else:
117
+ image = np.expand_dims(np.array(image), 0)
118
+
119
+ source_image_path = data_info.get('source_file_path', [])
120
+ source_image = []
121
+ if isinstance(source_image_path, list):
122
+ for _source_image_path in source_image_path:
123
+ if self.data_root is not None:
124
+ _source_image_path = os.path.join(self.data_root, _source_image_path)
125
+ _source_image = Image.open(_source_image_path).convert('RGB')
126
+ source_image.append(_source_image)
127
+ else:
128
+ if self.data_root is not None:
129
+ _source_image_path = os.path.join(self.data_root, source_image_path)
130
+ _source_image = Image.open(_source_image_path).convert('RGB')
131
+ source_image.append(_source_image)
132
+
133
+ if not self.enable_bucket:
134
+ raise ValueError("Not enable_bucket is not supported now. ")
135
+ else:
136
+ source_image = [np.array(_source_image) for _source_image in source_image]
137
+
138
+ if random.random() < self.text_drop_ratio:
139
+ text = ''
140
+ return image, source_image, text, 'image', image_path
141
+
142
+ def __len__(self):
143
+ return self.length
144
+
145
+ def __getitem__(self, idx):
146
+ data_info = self.dataset[idx % len(self.dataset)]
147
+ data_type = data_info.get('type', 'image')
148
+ while True:
149
+ sample = {}
150
+ try:
151
+ data_info_local = self.dataset[idx % len(self.dataset)]
152
+ data_type_local = data_info_local.get('type', 'image')
153
+ if data_type_local != data_type:
154
+ raise ValueError("data_type_local != data_type")
155
+
156
+ pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx)
157
+ sample["pixel_values"] = pixel_values
158
+ sample["source_pixel_values"] = source_pixel_values
159
+ sample["text"] = name
160
+ sample["data_type"] = data_type
161
+ sample["idx"] = idx
162
+ if self.return_file_name:
163
+ sample["file_name"] = os.path.basename(file_path)
164
+
165
+ if len(sample) > 0:
166
+ break
167
+ except Exception as e:
168
+ print(e, self.dataset[idx % len(self.dataset)])
169
+ idx = random.randint(0, self.length-1)
170
+
171
+ if self.enable_inpaint and not self.enable_bucket:
172
+ mask = get_random_mask(pixel_values.size())
173
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
174
+ sample["mask_pixel_values"] = mask_pixel_values
175
+ sample["mask"] = mask
176
+
177
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
178
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
179
+ sample["clip_pixel_values"] = clip_pixel_values
180
+
181
+ return sample
182
+
183
+ if __name__ == "__main__":
184
+ dataset = CC15M(
185
+ json_path="./cc15m_add_index.json",
186
+ resolution=512,
187
+ )
188
+
189
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
190
+ for idx, batch in enumerate(dataloader):
191
+ print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/dataset_image_video.py ADDED
@@ -0,0 +1,657 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from random import shuffle
10
+ from threading import Thread
11
+
12
+ import albumentations
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ from decord import VideoReader
19
+ from einops import rearrange
20
+ from func_timeout import FunctionTimedOut, func_timeout
21
+ from packaging import version as pver
22
+ from PIL import Image
23
+ from safetensors.torch import load_file
24
+ from torch.utils.data import BatchSampler, Sampler
25
+ from torch.utils.data.dataset import Dataset
26
+
27
+ from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
28
+ custom_meshgrid, get_random_mask, get_relative_pose,
29
+ get_video_reader_batch, padding_image, process_pose_file,
30
+ process_pose_params, ray_condition, resize_frame,
31
+ resize_image_with_target_area)
32
+
33
+
34
+ class ImageVideoSampler(BatchSampler):
35
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
36
+
37
+ Args:
38
+ sampler (Sampler): Base sampler.
39
+ dataset (Dataset): Dataset providing data information.
40
+ batch_size (int): Size of mini-batch.
41
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
42
+ its size would be less than ``batch_size``.
44
+ """
45
+
46
+ def __init__(self,
47
+ sampler: Sampler,
48
+ dataset: Dataset,
49
+ batch_size: int,
50
+ drop_last: bool = False
51
+ ) -> None:
52
+ if not isinstance(sampler, Sampler):
53
+ raise TypeError('sampler should be an instance of ``Sampler``, '
54
+ f'but got {sampler}')
55
+ if not isinstance(batch_size, int) or batch_size <= 0:
56
+ raise ValueError('batch_size should be a positive integer value, '
57
+ f'but got batch_size={batch_size}')
58
+ self.sampler = sampler
59
+ self.dataset = dataset
60
+ self.batch_size = batch_size
61
+ self.drop_last = drop_last
62
+
63
+ # buckets for each aspect ratio
64
+ self.bucket = {'image':[], 'video':[]}
65
+
66
+ def __iter__(self):
67
+ for idx in self.sampler:
68
+ content_type = self.dataset.dataset[idx].get('type', 'image')
69
+ self.bucket[content_type].append(idx)
70
+
71
+ # yield a batch of indices in the same aspect ratio group
72
+ if len(self.bucket['video']) == self.batch_size:
73
+ bucket = self.bucket['video']
74
+ yield bucket[:]
75
+ del bucket[:]
76
+ elif len(self.bucket['image']) == self.batch_size:
77
+ bucket = self.bucket['image']
78
+ yield bucket[:]
79
+ del bucket[:]
80
+
81
+
82
+ class ImageVideoDataset(Dataset):
83
+ def __init__(
84
+ self,
85
+ ann_path, data_root=None,
86
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
87
+ image_sample_size=512,
88
+ video_repeat=0,
89
+ text_drop_ratio=0.1,
90
+ enable_bucket=False,
91
+ video_length_drop_start=0.0,
92
+ video_length_drop_end=1.0,
93
+ enable_inpaint=False,
94
+ return_file_name=False,
95
+ ):
96
+ # Loading annotations from files
97
+ print(f"loading annotations from {ann_path} ...")
98
+ if ann_path.endswith('.csv'):
99
+ with open(ann_path, 'r') as csvfile:
100
+ dataset = list(csv.DictReader(csvfile))
101
+ elif ann_path.endswith('.json'):
102
+ dataset = json.load(open(ann_path))
103
+
104
+ self.data_root = data_root
105
+
106
+ # It's used to balance num of images and videos.
107
+ if video_repeat > 0:
108
+ self.dataset = []
109
+ for data in dataset:
110
+ if data.get('type', 'image') != 'video':
111
+ self.dataset.append(data)
112
+
113
+ for _ in range(video_repeat):
114
+ for data in dataset:
115
+ if data.get('type', 'image') == 'video':
116
+ self.dataset.append(data)
117
+ else:
118
+ self.dataset = dataset
119
+ del dataset
120
+
121
+ self.length = len(self.dataset)
122
+ print(f"data scale: {self.length}")
123
+ # TODO: enable bucket training
124
+ self.enable_bucket = enable_bucket
125
+ self.text_drop_ratio = text_drop_ratio
126
+ self.enable_inpaint = enable_inpaint
127
+ self.return_file_name = return_file_name
128
+
129
+ self.video_length_drop_start = video_length_drop_start
130
+ self.video_length_drop_end = video_length_drop_end
131
+
132
+ # Video params
133
+ self.video_sample_stride = video_sample_stride
134
+ self.video_sample_n_frames = video_sample_n_frames
135
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
136
+ self.video_transforms = transforms.Compose(
137
+ [
138
+ transforms.Resize(min(self.video_sample_size)),
139
+ transforms.CenterCrop(self.video_sample_size),
140
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
141
+ ]
142
+ )
143
+
144
+ # Image params
145
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
146
+ self.image_transforms = transforms.Compose([
147
+ transforms.Resize(min(self.image_sample_size)),
148
+ transforms.CenterCrop(self.image_sample_size),
149
+ transforms.ToTensor(),
150
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
151
+ ])
152
+
153
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
154
+
155
+ def get_batch(self, idx):
156
+ data_info = self.dataset[idx % len(self.dataset)]
157
+
158
+ if data_info.get('type', 'image')=='video':
159
+ video_id, text = data_info['file_path'], data_info['text']
160
+
161
+ if self.data_root is None:
162
+ video_dir = video_id
163
+ else:
164
+ video_dir = os.path.join(self.data_root, video_id)
165
+
166
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
167
+ min_sample_n_frames = min(
168
+ self.video_sample_n_frames,
169
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
170
+ )
171
+ if min_sample_n_frames == 0:
172
+ raise ValueError(f"No Frames in video.")
173
+
174
+ video_length = int(self.video_length_drop_end * len(video_reader))
175
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
176
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
177
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
178
+
179
+ try:
180
+ sample_args = (video_reader, batch_index)
181
+ pixel_values = func_timeout(
182
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
183
+ )
184
+ resized_frames = []
185
+ for i in range(len(pixel_values)):
186
+ frame = pixel_values[i]
187
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
188
+ resized_frames.append(resized_frame)
189
+ pixel_values = np.array(resized_frames)
190
+ except FunctionTimedOut:
191
+ raise ValueError(f"Read {idx} timeout.")
192
+ except Exception as e:
193
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
194
+
195
+ if not self.enable_bucket:
196
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
197
+ pixel_values = pixel_values / 255.
198
+ del video_reader
199
+ else:
200
+ pixel_values = pixel_values
201
+
202
+ if not self.enable_bucket:
203
+ pixel_values = self.video_transforms(pixel_values)
204
+
205
+ # Random use no text generation
206
+ if random.random() < self.text_drop_ratio:
207
+ text = ''
208
+ return pixel_values, text, 'video', video_dir
209
+ else:
210
+ image_path, text = data_info['file_path'], data_info['text']
211
+ if self.data_root is not None:
212
+ image_path = os.path.join(self.data_root, image_path)
213
+ image = Image.open(image_path).convert('RGB')
214
+ if not self.enable_bucket:
215
+ image = self.image_transforms(image).unsqueeze(0)
216
+ else:
217
+ image = np.expand_dims(np.array(image), 0)
218
+ if random.random() < self.text_drop_ratio:
219
+ text = ''
220
+ return image, text, 'image', image_path
221
+
222
+ def __len__(self):
223
+ return self.length
224
+
225
+ def __getitem__(self, idx):
226
+ data_info = self.dataset[idx % len(self.dataset)]
227
+ data_type = data_info.get('type', 'image')
228
+ while True:
229
+ sample = {}
230
+ try:
231
+ data_info_local = self.dataset[idx % len(self.dataset)]
232
+ data_type_local = data_info_local.get('type', 'image')
233
+ if data_type_local != data_type:
234
+ raise ValueError("data_type_local != data_type")
235
+
236
+ pixel_values, name, data_type, file_path = self.get_batch(idx)
237
+ sample["pixel_values"] = pixel_values
238
+ sample["text"] = name
239
+ sample["data_type"] = data_type
240
+ sample["idx"] = idx
241
+ if self.return_file_name:
242
+ sample["file_name"] = os.path.basename(file_path)
243
+
244
+ if len(sample) > 0:
245
+ break
246
+ except Exception as e:
247
+ print(e, self.dataset[idx % len(self.dataset)])
248
+ idx = random.randint(0, self.length-1)
249
+
250
+ if self.enable_inpaint and not self.enable_bucket:
251
+ mask = get_random_mask(pixel_values.size())
252
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
253
+ sample["mask_pixel_values"] = mask_pixel_values
254
+ sample["mask"] = mask
255
+
256
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
257
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
258
+ sample["clip_pixel_values"] = clip_pixel_values
259
+
260
+ return sample
261
+
262
+
263
+ class ImageVideoControlDataset(Dataset):
264
+ def __init__(
265
+ self,
266
+ ann_path, data_root=None,
267
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
268
+ image_sample_size=512,
269
+ video_repeat=0,
270
+ text_drop_ratio=0.1,
271
+ enable_bucket=False,
272
+ video_length_drop_start=0.1,
273
+ video_length_drop_end=0.9,
274
+ enable_inpaint=False,
275
+ enable_camera_info=False,
276
+ return_file_name=False,
277
+ enable_subject_info=False,
278
+ padding_subject_info=True,
279
+ ):
280
+ # Loading annotations from files
281
+ print(f"loading annotations from {ann_path} ...")
282
+ if ann_path.endswith('.csv'):
283
+ with open(ann_path, 'r') as csvfile:
284
+ dataset = list(csv.DictReader(csvfile))
285
+ elif ann_path.endswith('.json'):
286
+ dataset = json.load(open(ann_path))
287
+
288
+ self.data_root = data_root
289
+
290
+ # It's used to balance num of images and videos.
291
+ if video_repeat > 0:
292
+ self.dataset = []
293
+ for data in dataset:
294
+ if data.get('type', 'image') != 'video':
295
+ self.dataset.append(data)
296
+
297
+ for _ in range(video_repeat):
298
+ for data in dataset:
299
+ if data.get('type', 'image') == 'video':
300
+ self.dataset.append(data)
301
+ else:
302
+ self.dataset = dataset
303
+ del dataset
304
+
305
+ self.length = len(self.dataset)
306
+ print(f"data scale: {self.length}")
307
+ # TODO: enable bucket training
308
+ self.enable_bucket = enable_bucket
309
+ self.text_drop_ratio = text_drop_ratio
310
+ self.enable_inpaint = enable_inpaint
311
+ self.enable_camera_info = enable_camera_info
312
+ self.enable_subject_info = enable_subject_info
313
+ self.padding_subject_info = padding_subject_info
314
+
315
+ self.video_length_drop_start = video_length_drop_start
316
+ self.video_length_drop_end = video_length_drop_end
317
+
318
+ # Video params
319
+ self.video_sample_stride = video_sample_stride
320
+ self.video_sample_n_frames = video_sample_n_frames
321
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
322
+ self.video_transforms = transforms.Compose(
323
+ [
324
+ transforms.Resize(min(self.video_sample_size)),
325
+ transforms.CenterCrop(self.video_sample_size),
326
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
327
+ ]
328
+ )
329
+ if self.enable_camera_info:
330
+ self.video_transforms_camera = transforms.Compose(
331
+ [
332
+ transforms.Resize(min(self.video_sample_size)),
333
+ transforms.CenterCrop(self.video_sample_size)
334
+ ]
335
+ )
336
+
337
+ # Image params
338
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
339
+ self.image_transforms = transforms.Compose([
340
+ transforms.Resize(min(self.image_sample_size)),
341
+ transforms.CenterCrop(self.image_sample_size),
342
+ transforms.ToTensor(),
343
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
344
+ ])
345
+
346
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
347
+
348
+ def get_batch(self, idx):
349
+ data_info = self.dataset[idx % len(self.dataset)]
350
+ video_id, text = data_info['file_path'], data_info['text']
351
+
352
+ if data_info.get('type', 'image')=='video':
353
+ if self.data_root is None:
354
+ video_dir = video_id
355
+ else:
356
+ video_dir = os.path.join(self.data_root, video_id)
357
+
358
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
359
+ min_sample_n_frames = min(
360
+ self.video_sample_n_frames,
361
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
362
+ )
363
+ if min_sample_n_frames == 0:
364
+ raise ValueError(f"No Frames in video.")
365
+
366
+ video_length = int(self.video_length_drop_end * len(video_reader))
367
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
368
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
369
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
370
+
371
+ try:
372
+ sample_args = (video_reader, batch_index)
373
+ pixel_values = func_timeout(
374
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
375
+ )
376
+ resized_frames = []
377
+ for i in range(len(pixel_values)):
378
+ frame = pixel_values[i]
379
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
380
+ resized_frames.append(resized_frame)
381
+ pixel_values = np.array(resized_frames)
382
+ except FunctionTimedOut:
383
+ raise ValueError(f"Read {idx} timeout.")
384
+ except Exception as e:
385
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
386
+
387
+ if not self.enable_bucket:
388
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
389
+ pixel_values = pixel_values / 255.
390
+ del video_reader
391
+ else:
392
+ pixel_values = pixel_values
393
+
394
+ if not self.enable_bucket:
395
+ pixel_values = self.video_transforms(pixel_values)
396
+
397
+ # Random use no text generation
398
+ if random.random() < self.text_drop_ratio:
399
+ text = ''
400
+
401
+ control_video_id = data_info['control_file_path']
402
+
403
+ if control_video_id is not None:
404
+ if self.data_root is None:
405
+ control_video_id = control_video_id
406
+ else:
407
+ control_video_id = os.path.join(self.data_root, control_video_id)
408
+
409
+ if self.enable_camera_info:
410
+ if control_video_id.lower().endswith('.txt'):
411
+ if not self.enable_bucket:
412
+ control_pixel_values = torch.zeros_like(pixel_values)
413
+
414
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
415
+ control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
416
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
417
+ control_camera_values = self.video_transforms_camera(control_camera_values)
418
+ else:
419
+ control_pixel_values = np.zeros_like(pixel_values)
420
+
421
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
422
+ control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
423
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
424
+ control_camera_values = np.array([control_camera_values[index] for index in batch_index])
425
+ else:
426
+ if not self.enable_bucket:
427
+ control_pixel_values = torch.zeros_like(pixel_values)
428
+ control_camera_values = None
429
+ else:
430
+ control_pixel_values = np.zeros_like(pixel_values)
431
+ control_camera_values = None
432
+ else:
433
+ if control_video_id is not None:
434
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
435
+ try:
436
+ sample_args = (control_video_reader, batch_index)
437
+ control_pixel_values = func_timeout(
438
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
439
+ )
440
+ resized_frames = []
441
+ for i in range(len(control_pixel_values)):
442
+ frame = control_pixel_values[i]
443
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
444
+ resized_frames.append(resized_frame)
445
+ control_pixel_values = np.array(resized_frames)
446
+ except FunctionTimedOut:
447
+ raise ValueError(f"Read {idx} timeout.")
448
+ except Exception as e:
449
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
450
+
451
+ if not self.enable_bucket:
452
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
453
+ control_pixel_values = control_pixel_values / 255.
454
+ del control_video_reader
455
+ else:
456
+ control_pixel_values = control_pixel_values
457
+
458
+ if not self.enable_bucket:
459
+ control_pixel_values = self.video_transforms(control_pixel_values)
460
+ else:
461
+ if not self.enable_bucket:
462
+ control_pixel_values = torch.zeros_like(pixel_values)
463
+ else:
464
+ control_pixel_values = np.zeros_like(pixel_values)
465
+ control_camera_values = None
466
+
467
+ if self.enable_subject_info:
468
+ if not self.enable_bucket:
469
+ visual_height, visual_width = pixel_values.shape[-2:]
470
+ else:
471
+ visual_height, visual_width = pixel_values.shape[1:3]
472
+
473
+ subject_id = data_info.get('object_file_path', [])
474
+ shuffle(subject_id)
475
+ subject_images = []
476
+ for i in range(min(len(subject_id), 4)):
477
+ subject_image = Image.open(subject_id[i])
478
+ width, height = subject_image.size
479
+ total_pixels = width * height
480
+
481
+ if self.padding_subject_info:
482
+ img = padding_image(subject_image, visual_width, visual_height)
483
+ else:
484
+ img = resize_image_with_target_area(subject_image, 1024 * 1024)
485
+
486
+ if random.random() < 0.5:
487
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
488
+ subject_images.append(np.array(img))
489
+ if self.padding_subject_info:
490
+ subject_image = np.array(subject_images)
491
+ else:
492
+ subject_image = subject_images
493
+ else:
494
+ subject_image = None
495
+
496
+ return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
497
+ else:
498
+ image_path, text = data_info['file_path'], data_info['text']
499
+ if self.data_root is not None:
500
+ image_path = os.path.join(self.data_root, image_path)
501
+ image = Image.open(image_path).convert('RGB')
502
+ if not self.enable_bucket:
503
+ image = self.image_transforms(image).unsqueeze(0)
504
+ else:
505
+ image = np.expand_dims(np.array(image), 0)
506
+
507
+ if random.random() < self.text_drop_ratio:
508
+ text = ''
509
+
510
+ control_image_id = data_info['control_file_path']
511
+
512
+ if self.data_root is None:
513
+ control_image_id = control_image_id
514
+ else:
515
+ control_image_id = os.path.join(self.data_root, control_image_id)
516
+
517
+ control_image = Image.open(control_image_id).convert('RGB')
518
+ if not self.enable_bucket:
519
+ control_image = self.image_transforms(control_image).unsqueeze(0)
520
+ else:
521
+ control_image = np.expand_dims(np.array(control_image), 0)
522
+
523
+ if self.enable_subject_info:
524
+ if not self.enable_bucket:
525
+ visual_height, visual_width = image.shape[-2:]
526
+ else:
527
+ visual_height, visual_width = image.shape[1:3]
528
+
529
+ subject_id = data_info.get('object_file_path', [])
530
+ shuffle(subject_id)
531
+ subject_images = []
532
+ for i in range(min(len(subject_id), 4)):
533
+ subject_image = Image.open(subject_id[i]).convert('RGB')
534
+ width, height = subject_image.size
535
+ total_pixels = width * height
536
+
537
+ if self.padding_subject_info:
538
+ img = padding_image(subject_image, visual_width, visual_height)
539
+ else:
540
+ img = resize_image_with_target_area(subject_image, 1024 * 1024)
541
+
542
+ if random.random() < 0.5:
543
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
544
+ subject_images.append(np.array(img))
545
+ if self.padding_subject_info:
546
+ subject_image = np.array(subject_images)
547
+ else:
548
+ subject_image = subject_images
549
+ else:
550
+ subject_image = None
551
+
552
+ return image, control_image, subject_image, None, text, 'image'
553
+
554
+ def __len__(self):
555
+ return self.length
556
+
557
+ def __getitem__(self, idx):
558
+ data_info = self.dataset[idx % len(self.dataset)]
559
+ data_type = data_info.get('type', 'image')
560
+ while True:
561
+ sample = {}
562
+ try:
563
+ data_info_local = self.dataset[idx % len(self.dataset)]
564
+ data_type_local = data_info_local.get('type', 'image')
565
+ if data_type_local != data_type:
566
+ raise ValueError("data_type_local != data_type")
567
+
568
+ pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)
569
+
570
+ sample["pixel_values"] = pixel_values
571
+ sample["control_pixel_values"] = control_pixel_values
572
+ sample["subject_image"] = subject_image
573
+ sample["text"] = name
574
+ sample["data_type"] = data_type
575
+ sample["idx"] = idx
576
+
577
+ if self.enable_camera_info:
578
+ sample["control_camera_values"] = control_camera_values
579
+
580
+ if len(sample) > 0:
581
+ break
582
+ except Exception as e:
583
+ print(e, self.dataset[idx % len(self.dataset)])
584
+ idx = random.randint(0, self.length-1)
585
+
586
+ if self.enable_inpaint and not self.enable_bucket:
587
+ mask = get_random_mask(pixel_values.size())
588
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
589
+ sample["mask_pixel_values"] = mask_pixel_values
590
+ sample["mask"] = mask
591
+
592
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
593
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
594
+ sample["clip_pixel_values"] = clip_pixel_values
595
+
596
+ return sample
597
+
598
+
599
+ class ImageVideoSafetensorsDataset(Dataset):
600
+ def __init__(
601
+ self,
602
+ ann_path,
603
+ data_root=None,
604
+ ):
605
+ # Loading annotations from files
606
+ print(f"loading annotations from {ann_path} ...")
607
+ if ann_path.endswith('.json'):
608
+ dataset = json.load(open(ann_path))
609
+
610
+ self.data_root = data_root
611
+ self.dataset = dataset
612
+ self.length = len(self.dataset)
613
+ print(f"data scale: {self.length}")
614
+
615
+ def __len__(self):
616
+ return self.length
617
+
618
+ def __getitem__(self, idx):
619
+ if self.data_root is None:
620
+ path = self.dataset[idx]["file_path"]
621
+ else:
622
+ path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
623
+ state_dict = load_file(path)
624
+ return state_dict
625
+
626
+
627
+ class TextDataset(Dataset):
628
+ def __init__(self, ann_path, text_drop_ratio=0.0):
629
+ print(f"loading annotations from {ann_path} ...")
630
+ with open(ann_path, 'r') as f:
631
+ self.dataset = json.load(f)
632
+ self.length = len(self.dataset)
633
+ print(f"data scale: {self.length}")
634
+ self.text_drop_ratio = text_drop_ratio
635
+
636
+ def __len__(self):
637
+ return self.length
638
+
639
+ def __getitem__(self, idx):
640
+ while True:
641
+ try:
642
+ item = self.dataset[idx]
643
+ text = item['text']
644
+
645
+ # Randomly drop text (for classifier-free guidance)
646
+ if random.random() < self.text_drop_ratio:
647
+ text = ''
648
+
649
+ sample = {
650
+ "text": text,
651
+ "idx": idx
652
+ }
653
+ return sample
654
+
655
+ except Exception as e:
656
+ print(f"Error at index {idx}: {e}, retrying with random index...")
657
+ idx = np.random.randint(0, self.length - 1)
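A short sketch of how the pieces above can be wired together, assuming an annotation file named metadata.json with entries of type "image" and "video"; the file name, sizes and batch size are illustrative.

from torch.utils.data import DataLoader, RandomSampler

dataset = ImageVideoDataset(
    ann_path="metadata.json",
    data_root="./data",
    video_sample_n_frames=16,
    enable_bucket=True,   # returns raw numpy frames so bucketed resizing can happen later
)
batch_sampler = ImageVideoSampler(RandomSampler(dataset), dataset, batch_size=2)

# Every batch of indices is homogeneous: all image entries or all video entries.
for indices in batch_sampler:
    print(indices)

# In training, the sampler is typically handed to a DataLoader together with a collate_fn
# that can handle variable-sized numpy frames (my_collate_fn is a placeholder):
# loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=my_collate_fn)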
videox_fun/data/dataset_video.py ADDED
@@ -0,0 +1,901 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import librosa
14
+ import numpy as np
15
+ import torch
16
+ import torchvision.transforms as transforms
17
+ from decord import VideoReader
18
+ from einops import rearrange
19
+ from func_timeout import FunctionTimedOut, func_timeout
20
+ from PIL import Image
21
+ from torch.utils.data import BatchSampler, Sampler
22
+ from torch.utils.data.dataset import Dataset
23
+
24
+ from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
25
+ custom_meshgrid, get_random_mask, get_relative_pose,
26
+ get_video_reader_batch, padding_image, process_pose_file,
27
+ process_pose_params, ray_condition, resize_frame,
28
+ resize_image_with_target_area)
29
+
30
+
31
+ class WebVid10M(Dataset):
32
+ def __init__(
33
+ self,
34
+ csv_path, video_folder,
35
+ sample_size=256, sample_stride=4, sample_n_frames=16,
36
+ enable_bucket=False, enable_inpaint=False, is_image=False,
37
+ ):
38
+ print(f"loading annotations from {csv_path} ...")
39
+ with open(csv_path, 'r') as csvfile:
40
+ self.dataset = list(csv.DictReader(csvfile))
41
+ self.length = len(self.dataset)
42
+ print(f"data scale: {self.length}")
43
+
44
+ self.video_folder = video_folder
45
+ self.sample_stride = sample_stride
46
+ self.sample_n_frames = sample_n_frames
47
+ self.enable_bucket = enable_bucket
48
+ self.enable_inpaint = enable_inpaint
49
+ self.is_image = is_image
50
+
51
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
52
+ self.pixel_transforms = transforms.Compose([
53
+ transforms.Resize(sample_size[0]),
54
+ transforms.CenterCrop(sample_size),
55
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
56
+ ])
57
+
58
+ def get_batch(self, idx):
59
+ video_dict = self.dataset[idx]
60
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
61
+
62
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
63
+ video_reader = VideoReader(video_dir)
64
+ video_length = len(video_reader)
65
+
66
+ if not self.is_image:
67
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
68
+ start_idx = random.randint(0, video_length - clip_length)
69
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
70
+ else:
71
+ batch_index = [random.randint(0, video_length - 1)]
72
+
73
+ if not self.enable_bucket:
74
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
75
+ pixel_values = pixel_values / 255.
76
+ del video_reader
77
+ else:
78
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
79
+
80
+ if self.is_image:
81
+ pixel_values = pixel_values[0]
82
+ return pixel_values, name
83
+
84
+ def __len__(self):
85
+ return self.length
86
+
87
+ def __getitem__(self, idx):
88
+ while True:
89
+ try:
90
+ pixel_values, name = self.get_batch(idx)
91
+ break
92
+
93
+ except Exception as e:
94
+ print("Error info:", e)
95
+ idx = random.randint(0, self.length-1)
96
+
97
+ if not self.enable_bucket:
98
+ pixel_values = self.pixel_transforms(pixel_values)
99
+ if self.enable_inpaint:
100
+ mask = get_random_mask(pixel_values.size())
101
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
102
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
103
+ else:
104
+ sample = dict(pixel_values=pixel_values, text=name)
105
+ return sample
106
+
107
+
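WebVid10M above reads a CSV with at least the videoid, name and page_dir columns and resolves each clip as <video_folder>/<videoid>.mp4; a minimal illustrative annotation file can be written like this (all values are made up for the example).

import csv

rows = [
    {"videoid": "000001", "name": "a dog running on the beach", "page_dir": "part_000"},
    {"videoid": "000002", "name": "timelapse of clouds over a city", "page_dir": "part_000"},
]
with open("webvid_subset.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["videoid", "name", "page_dir"])
    writer.writeheader()
    writer.writerows(rows)

# dataset = WebVid10M(csv_path="webvid_subset.csv", video_folder="./videos", sample_n_frames=16)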
108
+ class VideoDataset(Dataset):
109
+ def __init__(
110
+ self,
111
+ ann_path, data_root=None,
112
+ sample_size=256, sample_stride=4, sample_n_frames=16,
113
+ enable_bucket=False, enable_inpaint=False,
+ video_length_drop_start=0.0, video_length_drop_end=1.0, text_drop_ratio=0.1,
114
+ ):
115
+ print(f"loading annotations from {ann_path} ...")
116
+ self.dataset = json.load(open(ann_path, 'r'))
117
+ self.length = len(self.dataset)
118
+ print(f"data scale: {self.length}")
119
+
120
+ self.data_root = data_root
121
+ # Store the attributes under the names that get_batch below expects.
+ self.video_sample_stride = sample_stride
+ self.video_sample_n_frames = sample_n_frames
+ self.video_length_drop_start = video_length_drop_start
+ self.video_length_drop_end = video_length_drop_end
+ self.text_drop_ratio = text_drop_ratio
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
125
+
126
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
+ self.video_transforms = transforms.Compose(
128
+ [
129
+ transforms.Resize(sample_size[0]),
130
+ transforms.CenterCrop(sample_size),
131
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
+ ]
133
+ )
134
+
135
+ def get_batch(self, idx):
136
+ video_dict = self.dataset[idx]
137
+ video_id, text = video_dict['file_path'], video_dict['text']
138
+
139
+ if self.data_root is None:
140
+ video_dir = video_id
141
+ else:
142
+ video_dir = os.path.join(self.data_root, video_id)
143
+
144
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
145
+ min_sample_n_frames = min(
146
+ self.video_sample_n_frames,
147
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
148
+ )
149
+ if min_sample_n_frames == 0:
150
+ raise ValueError(f"No Frames in video.")
151
+
152
+ video_length = int(self.video_length_drop_end * len(video_reader))
153
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
154
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
155
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
156
+
157
+ try:
158
+ sample_args = (video_reader, batch_index)
159
+ pixel_values = func_timeout(
160
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
161
+ )
162
+ except FunctionTimedOut:
163
+ raise ValueError(f"Read {idx} timeout.")
164
+ except Exception as e:
165
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
166
+
167
+ if not self.enable_bucket:
168
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
169
+ pixel_values = pixel_values / 255.
170
+ del video_reader
171
+ else:
172
+ pixel_values = pixel_values
173
+
174
+ if not self.enable_bucket:
175
+ pixel_values = self.video_transforms(pixel_values)
176
+
177
+ # Random use no text generation
178
+ if random.random() < self.text_drop_ratio:
179
+ text = ''
180
+ return pixel_values, text
181
+
182
+ def __len__(self):
183
+ return self.length
184
+
185
+ def __getitem__(self, idx):
186
+ while True:
187
+ sample = {}
188
+ try:
189
+ pixel_values, name = self.get_batch(idx)
190
+ sample["pixel_values"] = pixel_values
191
+ sample["text"] = name
192
+ sample["idx"] = idx
193
+ if len(sample) > 0:
194
+ break
195
+
196
+ except Exception as e:
197
+ print(e, self.dataset[idx % len(self.dataset)])
198
+ idx = random.randint(0, self.length-1)
199
+
200
+ if self.enable_inpaint and not self.enable_bucket:
201
+ mask = get_random_mask(pixel_values.size())
202
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
203
+ sample["mask_pixel_values"] = mask_pixel_values
204
+ sample["mask"] = mask
205
+
206
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
207
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
208
+ sample["clip_pixel_values"] = clip_pixel_values
209
+
210
+ return sample
211
+
212
+
213
+ class VideoSpeechDataset(Dataset):
214
+ def __init__(
215
+ self,
216
+ ann_path, data_root=None,
217
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
218
+ enable_bucket=False, enable_inpaint=False,
219
+ audio_sr=16000,  # target audio sample rate
220
+ text_drop_ratio=0.1  # probability of dropping the text prompt
221
+ ):
222
+ print(f"loading annotations from {ann_path} ...")
223
+ self.dataset = json.load(open(ann_path, 'r'))
224
+ self.length = len(self.dataset)
225
+ print(f"data scale: {self.length}")
226
+
227
+ self.data_root = data_root
228
+ self.video_sample_stride = video_sample_stride
229
+ self.video_sample_n_frames = video_sample_n_frames
230
+ self.enable_bucket = enable_bucket
231
+ self.enable_inpaint = enable_inpaint
232
+ self.audio_sr = audio_sr
233
+ self.text_drop_ratio = text_drop_ratio
234
+
235
+ video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
236
+ self.pixel_transforms = transforms.Compose(
237
+ [
238
+ transforms.Resize(video_sample_size[0]),
239
+ transforms.CenterCrop(video_sample_size),
240
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
241
+ ]
242
+ )
243
+
244
+ def get_batch(self, idx):
245
+ video_dict = self.dataset[idx]
246
+ video_id, text = video_dict['file_path'], video_dict['text']
247
+ audio_id = video_dict['audio_path']
248
+
249
+ if self.data_root is None:
250
+ video_path = video_id
251
+ else:
252
+ video_path = os.path.join(self.data_root, video_id)
253
+
254
+ if self.data_root is None:
255
+ audio_path = audio_id
256
+ else:
257
+ audio_path = os.path.join(self.data_root, audio_id)
258
+
259
+ if not os.path.exists(audio_path):
260
+ raise FileNotFoundError(f"Audio file not found for {video_path}")
261
+
262
+ with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
263
+ total_frames = len(video_reader)
264
+ fps = video_reader.get_avg_fps()  # original video frame rate
265
+
266
+ # Compute the number of frames actually sampled (respecting video boundaries)
267
+ max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
268
+ actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
269
+ if actual_n_frames <= 0:
270
+ raise ValueError(f"Video too short: {video_path}")
271
+
272
+ # Randomly pick the start frame
273
+ max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
274
+ start_frame = random.randint(0, max_start) if max_start > 0 else 0
275
+ frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]
276
+
277
+ # Read the video frames
278
+ try:
279
+ sample_args = (video_reader, frame_indices)
280
+ pixel_values = func_timeout(
281
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
282
+ )
283
+ except FunctionTimedOut:
284
+ raise ValueError(f"Read {idx} timeout.")
285
+ except Exception as e:
286
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
287
+
288
+ # Video post-processing
289
+ if not self.enable_bucket:
290
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
291
+ pixel_values = pixel_values / 255.
292
+ pixel_values = self.pixel_transforms(pixel_values)
293
+
294
+ # === Load and slice the matching audio segment ===
295
+ # Start and end times of the video clip, in seconds
296
+ start_time = start_frame / fps
297
+ end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
298
+ duration = end_time - start_time
299
+
300
+ # Load the whole audio with librosa (librosa.load has no precise seek, so load first and slice afterwards)
301
+ audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)  # resample to the target sample rate
302
+
303
+ # Convert times to sample indices
304
+ start_sample = int(start_time * self.audio_sr)
305
+ end_sample = int(end_time * self.audio_sr)
306
+
307
+ # Slice safely
308
+ if start_sample >= len(audio_input):
309
+ # Audio is too short: pad with zeros or truncate
310
+ audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
311
+ else:
312
+ audio_segment = audio_input[start_sample:end_sample]
313
+ # Pad with zeros if too short
314
+ target_len = int(duration * self.audio_sr)
315
+ if len(audio_segment) < target_len:
316
+ audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')
317
+
318
+ # === Random text dropout ===
319
+ if random.random() < self.text_drop_ratio:
320
+ text = ''
321
+
322
+ return pixel_values, text, audio_segment, sample_rate
323
+
324
+ def __len__(self):
325
+ return self.length
326
+
327
+ def __getitem__(self, idx):
328
+ while True:
329
+ sample = {}
330
+ try:
331
+ pixel_values, text, audio, sample_rate = self.get_batch(idx)
332
+ sample["pixel_values"] = pixel_values
333
+ sample["text"] = text
334
+ sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
335
+ sample["sample_rate"] = sample_rate
336
+ sample["idx"] = idx
337
+ break
338
+ except Exception as e:
339
+ print(f"Error processing {idx}: {e}, retrying with random idx...")
340
+ idx = random.randint(0, self.length - 1)
341
+
342
+ if self.enable_inpaint and not self.enable_bucket:
343
+ mask = get_random_mask(pixel_values.size(), image_start_only=True)
344
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
345
+ sample["mask_pixel_values"] = mask_pixel_values
346
+ sample["mask"] = mask
347
+
348
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
349
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
350
+ sample["clip_pixel_values"] = clip_pixel_values
351
+
352
+ return sample
353
+
354
+
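To make the audio/video alignment in VideoSpeechDataset.get_batch concrete, here is the same arithmetic with illustrative numbers (the fps, stride, frame count and sample rate are made up for the example).

# A sampled clip of 16 frames with stride 4, starting at frame 100 of a 25 fps video.
fps = 25.0
stride = 4            # video_sample_stride
n_frames = 16         # actual_n_frames
start_frame = 100
audio_sr = 16000      # target audio sample rate

start_time = start_frame / fps                              # 4.0 s
end_time = (start_frame + (n_frames - 1) * stride) / fps    # 6.4 s
duration = end_time - start_time                            # 2.4 s

start_sample = int(start_time * audio_sr)                   # 64000
end_sample = int(end_time * audio_sr)                       # 102400
target_len = int(duration * audio_sr)                       # 38400 samples to pad or crop to
print(start_sample, end_sample, target_len)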
355
+ class VideoSpeechControlDataset(Dataset):
356
+ def __init__(
357
+ self,
358
+ ann_path, data_root=None,
359
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
360
+ enable_bucket=False, enable_inpaint=False,
361
+ audio_sr=16000,
362
+ text_drop_ratio=0.1,
363
+ enable_motion_info=False,
364
+ motion_frames=73,
365
+ ):
366
+ print(f"loading annotations from {ann_path} ...")
367
+ self.dataset = json.load(open(ann_path, 'r'))
368
+ self.length = len(self.dataset)
369
+ print(f"data scale: {self.length}")
370
+
371
+ self.data_root = data_root
372
+ self.video_sample_stride = video_sample_stride
373
+ self.video_sample_n_frames = video_sample_n_frames
374
+ self.enable_bucket = enable_bucket
375
+ self.enable_inpaint = enable_inpaint
376
+ self.audio_sr = audio_sr
377
+ self.text_drop_ratio = text_drop_ratio
378
+ self.enable_motion_info = enable_motion_info
379
+ self.motion_frames = motion_frames
380
+
381
+ video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
382
+ self.pixel_transforms = transforms.Compose(
383
+ [
384
+ transforms.Resize(video_sample_size[0]),
385
+ transforms.CenterCrop(video_sample_size),
386
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
387
+ ]
388
+ )
389
+
390
+ self.video_sample_size = video_sample_size
391
+
392
+ def get_batch(self, idx):
393
+ video_dict = self.dataset[idx]
394
+ video_id, text = video_dict['file_path'], video_dict['text']
395
+ audio_id = video_dict['audio_path']
396
+ control_video_id = video_dict['control_file_path']
397
+
398
+ if self.data_root is None:
399
+ video_path = video_id
400
+ else:
401
+ video_path = os.path.join(self.data_root, video_id)
402
+
403
+ if self.data_root is None:
404
+ audio_path = audio_id
405
+ else:
406
+ audio_path = os.path.join(self.data_root, audio_id)
407
+
408
+ if self.data_root is None:
409
+ control_video_id = control_video_id
410
+ else:
411
+ control_video_id = os.path.join(self.data_root, control_video_id)
412
+
413
+ if not os.path.exists(audio_path):
414
+ raise FileNotFoundError(f"Audio file not found for {video_path}")
415
+
416
+ # Video information
417
+ with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
418
+ total_frames = len(video_reader)
419
+ fps = video_reader.get_avg_fps()
420
+ if fps <= 0:
421
+ raise ValueError(f"Video has non-positive fps: {video_path}")
422
+ local_video_sample_stride = self.video_sample_stride
423
+ new_fps = int(fps // local_video_sample_stride)
424
+ while new_fps > 30:
425
+ local_video_sample_stride = local_video_sample_stride + 1
426
+ new_fps = int(fps // local_video_sample_stride)
427
+
428
+ max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
429
+ actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
430
+ if actual_n_frames <= 0:
431
+ raise ValueError(f"Video too short: {video_path}")
432
+
433
+ max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
434
+ start_frame = random.randint(0, max_start) if max_start > 0 else 0
435
+ frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]
436
+
437
+ try:
438
+ sample_args = (video_reader, frame_indices)
439
+ pixel_values = func_timeout(
440
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
441
+ )
442
+ except FunctionTimedOut:
443
+ raise ValueError(f"Read {idx} timeout.")
444
+ except Exception as e:
445
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
446
+
447
+ _, height, width, channel = np.shape(pixel_values)
448
+ if self.enable_motion_info:
449
+ motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
450
+ if start_frame > 0:
451
+ motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
452
+ motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
453
+ motion_frame_indices = motion_frame_indices[-self.motion_frames:]
454
+
455
+ _motion_sample_args = (video_reader, motion_frame_indices)
456
+ _motion_pixel_values = func_timeout(
457
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
458
+ )
459
+ motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values
460
+
461
+ if not self.enable_bucket:
462
+ motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
463
+ motion_pixel_values = motion_pixel_values / 255.
464
+ motion_pixel_values = self.pixel_transforms(motion_pixel_values)
465
+ else:
466
+ motion_pixel_values = None
467
+
468
+ if not self.enable_bucket:
469
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
470
+ pixel_values = pixel_values / 255.
471
+ pixel_values = self.pixel_transforms(pixel_values)
472
+
473
+ # Audio information
474
+ start_time = start_frame / fps
475
+ end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
476
+ duration = end_time - start_time
477
+
478
+ audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
479
+ start_sample = int(start_time * self.audio_sr)
480
+ end_sample = int(end_time * self.audio_sr)
481
+
482
+ if start_sample >= len(audio_input):
483
+ raise ValueError(f"Audio file too short: {audio_path}")
484
+ else:
485
+ audio_segment = audio_input[start_sample:end_sample]
486
+ target_len = int(duration * self.audio_sr)
487
+ if len(audio_segment) < target_len:
488
+ raise ValueError(f"Audio file too short: {audio_path}")
489
+
490
+ # Control information
491
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
492
+ try:
493
+ sample_args = (control_video_reader, frame_indices)
494
+ control_pixel_values = func_timeout(
495
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
496
+ )
497
+ resized_frames = []
498
+ for i in range(len(control_pixel_values)):
499
+ frame = control_pixel_values[i]
500
+ resized_frame = resize_frame(frame, max(self.video_sample_size))
501
+ resized_frames.append(resized_frame)
502
+ control_pixel_values = np.array(resized_frames)
503
+ except FunctionTimedOut:
504
+ raise ValueError(f"Read {idx} timeout.")
505
+ except Exception as e:
506
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
507
+
508
+ if not self.enable_bucket:
509
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
510
+ control_pixel_values = control_pixel_values / 255.
511
+ del control_video_reader
512
+ else:
513
+ control_pixel_values = control_pixel_values
514
+
515
+ if not self.enable_bucket:
516
+ control_pixel_values = self.pixel_transforms(control_pixel_values)
517
+
518
+ if random.random() < self.text_drop_ratio:
519
+ text = ''
520
+
521
+ return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
522
+
523
+ def __len__(self):
524
+ return self.length
525
+
526
+ def __getitem__(self, idx):
527
+ while True:
528
+ sample = {}
529
+ try:
530
+ pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
531
+ sample["pixel_values"] = pixel_values
532
+ sample["motion_pixel_values"] = motion_pixel_values
533
+ sample["control_pixel_values"] = control_pixel_values
534
+ sample["text"] = text
535
+ sample["audio"] = torch.from_numpy(audio).float() # convert to tensor
536
+ sample["sample_rate"] = sample_rate
537
+ sample["fps"] = new_fps
538
+ sample["idx"] = idx
539
+ break
540
+ except Exception as e:
541
+ print(f"Error processing {idx}: {e}, retrying with random idx...")
542
+ idx = random.randint(0, self.length - 1)
543
+
544
+ if self.enable_inpaint and not self.enable_bucket:
545
+ mask = get_random_mask(pixel_values.size(), image_start_only=True)
546
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
547
+ sample["mask_pixel_values"] = mask_pixel_values
548
+ sample["mask"] = mask
549
+
550
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
551
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
552
+ sample["clip_pixel_values"] = clip_pixel_values
553
+
554
+ return sample
555
+
556
+
557
+ class VideoAnimateDataset(Dataset):
558
+ def __init__(
559
+ self,
560
+ ann_path, data_root=None,
561
+ video_sample_size=512,
562
+ video_sample_stride=4,
563
+ video_sample_n_frames=16,
564
+ video_repeat=0,
565
+ text_drop_ratio=0.1,
566
+ enable_bucket=False,
567
+ video_length_drop_start=0.1,
568
+ video_length_drop_end=0.9,
569
+ return_file_name=False,
570
+ ):
571
+ # Loading annotations from files
572
+ print(f"loading annotations from {ann_path} ...")
573
+ if ann_path.endswith('.csv'):
574
+ with open(ann_path, 'r') as csvfile:
575
+ dataset = list(csv.DictReader(csvfile))
576
+ elif ann_path.endswith('.json'):
577
+ dataset = json.load(open(ann_path))
578
+
579
+ self.data_root = data_root
580
+
581
+ # It's used to balance num of images and videos.
582
+ if video_repeat > 0:
583
+ self.dataset = []
584
+ for data in dataset:
585
+ if data.get('type', 'image') != 'video':
586
+ self.dataset.append(data)
587
+
588
+ for _ in range(video_repeat):
589
+ for data in dataset:
590
+ if data.get('type', 'image') == 'video':
591
+ self.dataset.append(data)
592
+ else:
593
+ self.dataset = dataset
594
+ del dataset
595
+
596
+ self.length = len(self.dataset)
597
+ print(f"data scale: {self.length}")
598
+ # TODO: enable bucket training
599
+ self.enable_bucket = enable_bucket
600
+ self.text_drop_ratio = text_drop_ratio
601
+
602
+ self.video_length_drop_start = video_length_drop_start
603
+ self.video_length_drop_end = video_length_drop_end
604
+
605
+ # Video params
606
+ self.video_sample_stride = video_sample_stride
607
+ self.video_sample_n_frames = video_sample_n_frames
608
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
609
+ self.video_transforms = transforms.Compose(
610
+ [
611
+ transforms.Resize(min(self.video_sample_size)),
612
+ transforms.CenterCrop(self.video_sample_size),
613
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
614
+ ]
615
+ )
616
+
617
+ self.larger_side_of_image_and_video = min(self.video_sample_size)
618
+
619
+ def get_batch(self, idx):
620
+ data_info = self.dataset[idx % len(self.dataset)]
621
+ video_id, text = data_info['file_path'], data_info['text']
622
+
623
+ if self.data_root is None:
624
+ video_dir = video_id
625
+ else:
626
+ video_dir = os.path.join(self.data_root, video_id)
627
+
628
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
629
+ min_sample_n_frames = min(
630
+ self.video_sample_n_frames,
631
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
632
+ )
633
+ if min_sample_n_frames == 0:
634
+ raise ValueError("No frames in video.")
635
+
636
+ video_length = int(self.video_length_drop_end * len(video_reader))
637
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
638
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
639
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
640
+
641
+ try:
642
+ sample_args = (video_reader, batch_index)
643
+ pixel_values = func_timeout(
644
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
645
+ )
646
+ resized_frames = []
647
+ for i in range(len(pixel_values)):
648
+ frame = pixel_values[i]
649
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
650
+ resized_frames.append(resized_frame)
651
+ pixel_values = np.array(resized_frames)
652
+ except FunctionTimedOut:
653
+ raise ValueError(f"Read {idx} timeout.")
654
+ except Exception as e:
655
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
656
+
657
+ if not self.enable_bucket:
658
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
659
+ pixel_values = pixel_values / 255.
660
+ del video_reader
661
+ else:
662
+ pixel_values = pixel_values
663
+
664
+ if not self.enable_bucket:
665
+ pixel_values = self.video_transforms(pixel_values)
666
+
667
+ # Random use no text generation
668
+ if random.random() < self.text_drop_ratio:
669
+ text = ''
670
+
671
+ control_video_id = data_info['control_file_path']
672
+
673
+ if control_video_id is not None:
674
+ if self.data_root is None:
675
+ control_video_id = control_video_id
676
+ else:
677
+ control_video_id = os.path.join(self.data_root, control_video_id)
678
+
679
+ if control_video_id is not None:
680
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
681
+ try:
682
+ sample_args = (control_video_reader, batch_index)
683
+ control_pixel_values = func_timeout(
684
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
685
+ )
686
+ resized_frames = []
687
+ for i in range(len(control_pixel_values)):
688
+ frame = control_pixel_values[i]
689
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
690
+ resized_frames.append(resized_frame)
691
+ control_pixel_values = np.array(resized_frames)
692
+ except FunctionTimedOut:
693
+ raise ValueError(f"Read {idx} timeout.")
694
+ except Exception as e:
695
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
696
+
697
+ if not self.enable_bucket:
698
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
699
+ control_pixel_values = control_pixel_values / 255.
700
+ del control_video_reader
701
+ else:
702
+ control_pixel_values = control_pixel_values
703
+
704
+ if not self.enable_bucket:
705
+ control_pixel_values = self.video_transforms(control_pixel_values)
706
+ else:
707
+ if not self.enable_bucket:
708
+ control_pixel_values = torch.zeros_like(pixel_values)
709
+ else:
710
+ control_pixel_values = np.zeros_like(pixel_values)
711
+
712
+ face_video_id = data_info['face_file_path']
713
+
714
+ if face_video_id is not None:
715
+ if self.data_root is None:
716
+ face_video_id = face_video_id
717
+ else:
718
+ face_video_id = os.path.join(self.data_root, face_video_id)
719
+
720
+ if face_video_id is not None:
721
+ with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
722
+ try:
723
+ sample_args = (face_video_reader, batch_index)
724
+ face_pixel_values = func_timeout(
725
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
726
+ )
727
+ resized_frames = []
728
+ for i in range(len(face_pixel_values)):
729
+ frame = face_pixel_values[i]
730
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
731
+ resized_frames.append(resized_frame)
732
+ face_pixel_values = np.array(resized_frames)
733
+ except FunctionTimedOut:
734
+ raise ValueError(f"Read {idx} timeout.")
735
+ except Exception as e:
736
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
737
+
738
+ if not self.enable_bucket:
739
+ face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
740
+ face_pixel_values = face_pixel_values / 255.
741
+ del face_video_reader
742
+ else:
743
+ face_pixel_values = face_pixel_values
744
+
745
+ if not self.enable_bucket:
746
+ face_pixel_values = self.video_transforms(face_pixel_values)
747
+ else:
748
+ if not self.enable_bucket:
749
+ face_pixel_values = torch.zeros_like(pixel_values)
750
+ else:
751
+ face_pixel_values = np.zeros_like(pixel_values)
752
+
753
+ background_video_id = data_info.get('background_file_path', None)
754
+
755
+ if background_video_id is not None:
756
+ if self.data_root is None:
757
+ background_video_id = background_video_id
758
+ else:
759
+ background_video_id = os.path.join(self.data_root, background_video_id)
760
+
761
+ if background_video_id is not None:
762
+ with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
763
+ try:
764
+ sample_args = (background_video_reader, batch_index)
765
+ background_pixel_values = func_timeout(
766
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
767
+ )
768
+ resized_frames = []
769
+ for i in range(len(background_pixel_values)):
770
+ frame = background_pixel_values[i]
771
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
772
+ resized_frames.append(resized_frame)
773
+ background_pixel_values = np.array(resized_frames)
774
+ except FunctionTimedOut:
775
+ raise ValueError(f"Read {idx} timeout.")
776
+ except Exception as e:
777
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
778
+
779
+ if not self.enable_bucket:
780
+ background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
781
+ background_pixel_values = background_pixel_values / 255.
782
+ del background_video_reader
783
+ else:
784
+ background_pixel_values = background_pixel_values
785
+
786
+ if not self.enable_bucket:
787
+ background_pixel_values = self.video_transforms(background_pixel_values)
788
+ else:
789
+ if not self.enable_bucket:
790
+ background_pixel_values = torch.ones_like(pixel_values) * 127.5
791
+ else:
792
+ background_pixel_values = np.ones_like(pixel_values) * 127.5
793
+
794
+ mask_video_id = data_info.get('mask_file_path', None)
795
+
796
+ if mask_video_id is not None:
797
+ if self.data_root is None:
798
+ mask_video_id = mask_video_id
799
+ else:
800
+ mask_video_id = os.path.join(self.data_root, mask_video_id)
801
+
802
+ if mask_video_id is not None:
803
+ with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
804
+ try:
805
+ sample_args = (mask_video_reader, batch_index)
806
+ mask = func_timeout(
807
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
808
+ )
809
+ resized_frames = []
810
+ for i in range(len(mask)):
811
+ frame = mask[i]
812
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
813
+ resized_frames.append(resized_frame)
814
+ mask = np.array(resized_frames)
815
+ except FunctionTimedOut:
816
+ raise ValueError(f"Read {idx} timeout.")
817
+ except Exception as e:
818
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
819
+
820
+ if not self.enable_bucket:
821
+ mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
822
+ mask = mask / 255.
823
+ del mask_video_reader
824
+ else:
825
+ mask = mask
826
+ else:
827
+ if not self.enable_bucket:
828
+ mask = torch.ones_like(pixel_values)
829
+ else:
830
+ mask = np.ones_like(pixel_values) * 255
831
+ mask = mask[:, :, :, :1]
832
+
833
+ ref_pixel_values_path = data_info.get('ref_file_path', [])
834
+ if self.data_root is not None:
835
+ ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
836
+ ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
837
+
838
+ if not self.enable_bucket:
839
+ raise ValueError("enable_bucket=False is not supported yet.")
840
+ else:
841
+ ref_pixel_values = np.array(ref_pixel_values)
842
+
843
+ return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
844
+
845
+ def __len__(self):
846
+ return self.length
847
+
848
+ def __getitem__(self, idx):
849
+ data_info = self.dataset[idx % len(self.dataset)]
850
+ data_type = data_info.get('type', 'image')
851
+ while True:
852
+ sample = {}
853
+ try:
854
+ data_info_local = self.dataset[idx % len(self.dataset)]
855
+ data_type_local = data_info_local.get('type', 'image')
856
+ if data_type_local != data_type:
857
+ raise ValueError("data_type_local != data_type")
858
+
859
+ pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
860
+ self.get_batch(idx)
861
+
862
+ sample["pixel_values"] = pixel_values
863
+ sample["control_pixel_values"] = control_pixel_values
864
+ sample["face_pixel_values"] = face_pixel_values
865
+ sample["background_pixel_values"] = background_pixel_values
866
+ sample["mask"] = mask
867
+ sample["ref_pixel_values"] = ref_pixel_values
868
+ sample["clip_pixel_values"] = ref_pixel_values
869
+ sample["text"] = name
870
+ sample["data_type"] = data_type
871
+ sample["idx"] = idx
872
+
873
+ if len(sample) > 0:
874
+ break
875
+ except Exception as e:
876
+ print(e, self.dataset[idx % len(self.dataset)])
877
+ idx = random.randint(0, self.length-1)
878
+
879
+ return sample
880
+
881
+
882
+ if __name__ == "__main__":
883
+ if 1:
884
+ dataset = VideoDataset(
885
+ json_path="./webvidval/results_2M_val.json",
886
+ sample_size=256,
887
+ sample_stride=4, sample_n_frames=16,
888
+ )
889
+
890
+ if 0:
891
+ dataset = WebVid10M(
892
+ csv_path="./webvid/results_2M_val.csv",
893
+ video_folder="./webvid/2M_val",
894
+ sample_size=256,
895
+ sample_stride=4, sample_n_frames=16,
896
+ is_image=False,
897
+ )
898
+
899
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
900
+ for idx, batch in enumerate(dataloader):
901
+ print(batch["pixel_values"].shape, len(batch["text"]))
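For reference, a minimal sketch of the annotation entries these datasets appear to expect, inferred from the keys read in `get_batch` above; every path and value below is an illustrative placeholder, not part of the repository.

# Hypothetical annotation entries (keys taken from get_batch; paths are placeholders).
speech_control_entry = {
    "file_path": "videos/clip_0001.mp4",            # RGB video
    "text": "a person talking to the camera",       # caption (randomly dropped at train time)
    "audio_path": "audios/clip_0001.wav",           # speech track, resampled to audio_sr
    "control_file_path": "controls/clip_0001.mp4",  # control video (e.g. pose)
}

animate_entry = {
    "type": "video",
    "file_path": "videos/clip_0002.mp4",
    "text": "a dancer in a studio",
    "control_file_path": "controls/clip_0002.mp4",
    "face_file_path": "faces/clip_0002.mp4",
    "background_file_path": "backgrounds/clip_0002.mp4",  # optional
    "mask_file_path": "masks/clip_0002.mp4",               # optional
    "ref_file_path": "refs/clip_0002.jpg",
}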
videox_fun/data/utils.py ADDED
@@ -0,0 +1,347 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from random import shuffle
10
+ from threading import Thread
11
+
12
+ import albumentations
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ from decord import VideoReader
19
+ from einops import rearrange
20
+ from func_timeout import FunctionTimedOut, func_timeout
21
+ from packaging import version as pver
22
+ from PIL import Image
23
+ from safetensors.torch import load_file
24
+ from torch.utils.data import BatchSampler, Sampler
25
+ from torch.utils.data.dataset import Dataset
26
+
27
+ VIDEO_READER_TIMEOUT = 20
28
+
29
+ def get_random_mask(shape, image_start_only=False):
30
+ f, c, h, w = shape
31
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
32
+
33
+ if not image_start_only:
34
+ if f != 1:
35
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
36
+ else:
37
+ mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05])
38
+ if mask_index == 0:
39
+ center_x = torch.randint(0, w, (1,)).item()
40
+ center_y = torch.randint(0, h, (1,)).item()
41
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
42
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
43
+
44
+ start_x = max(center_x - block_size_x // 2, 0)
45
+ end_x = min(center_x + block_size_x // 2, w)
46
+ start_y = max(center_y - block_size_y // 2, 0)
47
+ end_y = min(center_y + block_size_y // 2, h)
48
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
49
+ elif mask_index == 1:
50
+ mask[:, :, :, :] = 1
51
+ elif mask_index == 2:
52
+ mask_frame_index = np.random.randint(1, 5)
53
+ mask[mask_frame_index:, :, :, :] = 1
54
+ elif mask_index == 3:
55
+ mask_frame_index = np.random.randint(1, 5)
56
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
57
+ elif mask_index == 4:
58
+ center_x = torch.randint(0, w, (1,)).item()
59
+ center_y = torch.randint(0, h, (1,)).item()
60
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
61
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
62
+
63
+ start_x = max(center_x - block_size_x // 2, 0)
64
+ end_x = min(center_x + block_size_x // 2, w)
65
+ start_y = max(center_y - block_size_y // 2, 0)
66
+ end_y = min(center_y + block_size_y // 2, h)
67
+
68
+ mask_frame_before = np.random.randint(0, f // 2)
69
+ mask_frame_after = np.random.randint(f // 2, f)
70
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
71
+ elif mask_index == 5:
72
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
73
+ elif mask_index == 6:
74
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
75
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
76
+
77
+ for i in frames_to_mask:
78
+ block_height = random.randint(1, h // 4)
79
+ block_width = random.randint(1, w // 4)
80
+ top_left_y = random.randint(0, h - block_height)
81
+ top_left_x = random.randint(0, w - block_width)
82
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
83
+ elif mask_index == 7:
84
+ center_x = torch.randint(0, w, (1,)).item()
85
+ center_y = torch.randint(0, h, (1,)).item()
86
+ a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis
87
+ b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis
88
+
89
+ for i in range(h):
90
+ for j in range(w):
91
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
92
+ mask[:, :, i, j] = 1
93
+ elif mask_index == 8:
94
+ center_x = torch.randint(0, w, (1,)).item()
95
+ center_y = torch.randint(0, h, (1,)).item()
96
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
97
+ for i in range(h):
98
+ for j in range(w):
99
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
100
+ mask[:, :, i, j] = 1
101
+ elif mask_index == 9:
102
+ for idx in range(f):
103
+ if np.random.rand() > 0.5:
104
+ mask[idx, :, :, :] = 1
105
+ else:
106
+ raise ValueError(f"The mask_index {mask_index} is not defined")
107
+ else:
108
+ if f != 1:
109
+ mask[1:, :, :, :] = 1
110
+ else:
111
+ mask[:, :, :, :] = 1
112
+ return mask
113
+
114
+ @contextmanager
115
+ def VideoReader_contextmanager(*args, **kwargs):
116
+ vr = VideoReader(*args, **kwargs)
117
+ try:
118
+ yield vr
119
+ finally:
120
+ del vr
121
+ gc.collect()
122
+
123
+ def get_video_reader_batch(video_reader, batch_index):
124
+ frames = video_reader.get_batch(batch_index).asnumpy()
125
+ return frames
126
+
127
+ def resize_frame(frame, target_short_side):
128
+ h, w, _ = frame.shape
129
+ if h < w:
130
+ if target_short_side > h:
131
+ return frame
132
+ new_h = target_short_side
133
+ new_w = int(target_short_side * w / h)
134
+ else:
135
+ if target_short_side > w:
136
+ return frame
137
+ new_w = target_short_side
138
+ new_h = int(target_short_side * h / w)
139
+
140
+ resized_frame = cv2.resize(frame, (new_w, new_h))
141
+ return resized_frame
142
+
143
+ def padding_image(images, new_width, new_height):
144
+ new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
145
+
146
+ aspect_ratio = images.width / images.height
147
+ if new_width / new_height > 1:
148
+ if aspect_ratio > new_width / new_height:
149
+ new_img_width = new_width
150
+ new_img_height = int(new_img_width / aspect_ratio)
151
+ else:
152
+ new_img_height = new_height
153
+ new_img_width = int(new_img_height * aspect_ratio)
154
+ else:
155
+ if aspect_ratio > new_width / new_height:
156
+ new_img_width = new_width
157
+ new_img_height = int(new_img_width / aspect_ratio)
158
+ else:
159
+ new_img_height = new_height
160
+ new_img_width = int(new_img_height * aspect_ratio)
161
+
162
+ resized_img = images.resize((new_img_width, new_img_height))
163
+
164
+ paste_x = (new_width - new_img_width) // 2
165
+ paste_y = (new_height - new_img_height) // 2
166
+
167
+ new_image.paste(resized_img, (paste_x, paste_y))
168
+
169
+ return new_image
170
+
171
+ def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image:
172
+ """
173
+ Resize a PIL image so that its total pixel area is close to target_area, preserving the original aspect ratio
174
+ and ensuring that the new width and height are both multiples of 32.
175
+
176
+ Args:
177
+ img (PIL.Image.Image): input image
178
+ target_area (int): target total pixel area, e.g. 1024*1024 = 1048576
179
+
180
+ Returns:
181
+ PIL.Image.Image: the resized image
182
+ """
183
+ orig_w, orig_h = img.size
184
+ if orig_w == 0 or orig_h == 0:
185
+ raise ValueError("Input image has zero width or height.")
186
+
187
+ ratio = orig_w / orig_h
188
+ ideal_width = math.sqrt(target_area * ratio)
189
+ ideal_height = ideal_width / ratio
190
+
191
+ new_width = round(ideal_width / 32) * 32
192
+ new_height = round(ideal_height / 32) * 32
193
+
194
+ new_width = max(32, new_width)
195
+ new_height = max(32, new_height)
196
+
197
+ new_width = int(new_width)
198
+ new_height = int(new_height)
199
+
200
+ resized_img = img.resize((new_width, new_height), Image.LANCZOS)
201
+ return resized_img
202
+
203
+ class Camera(object):
204
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
205
+ """
206
+ def __init__(self, entry):
207
+ fx, fy, cx, cy = entry[1:5]
208
+ self.fx = fx
209
+ self.fy = fy
210
+ self.cx = cx
211
+ self.cy = cy
212
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
213
+ w2c_mat_4x4 = np.eye(4)
214
+ w2c_mat_4x4[:3, :] = w2c_mat
215
+ self.w2c_mat = w2c_mat_4x4
216
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
217
+
218
+ def custom_meshgrid(*args):
219
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
220
+ """
221
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
222
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
223
+ return torch.meshgrid(*args)
224
+ else:
225
+ return torch.meshgrid(*args, indexing='ij')
226
+
227
+ def get_relative_pose(cam_params):
228
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
229
+ """
230
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
231
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
232
+ cam_to_origin = 0
233
+ target_cam_c2w = np.array([
234
+ [1, 0, 0, 0],
235
+ [0, 1, 0, -cam_to_origin],
236
+ [0, 0, 1, 0],
237
+ [0, 0, 0, 1]
238
+ ])
239
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
240
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
241
+ ret_poses = np.array(ret_poses, dtype=np.float32)
242
+ return ret_poses
243
+
244
+ def ray_condition(K, c2w, H, W, device):
245
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
246
+ """
247
+ # c2w: B, V, 4, 4
248
+ # K: B, V, 4
249
+
250
+ B = K.shape[0]
251
+
252
+ j, i = custom_meshgrid(
253
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
254
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
255
+ )
256
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
257
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
258
+
259
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
260
+
261
+ zs = torch.ones_like(i) # [B, HxW]
262
+ xs = (i - cx) / fx * zs
263
+ ys = (j - cy) / fy * zs
264
+ zs = zs.expand_as(ys)
265
+
266
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
267
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
268
+
269
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
270
+ rays_o = c2w[..., :3, 3] # B, V, 3
271
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
272
+ # c2w @ directions
273
+ rays_dxo = torch.cross(rays_o, rays_d)
274
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
275
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
276
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
277
+ return plucker
278
+
279
+ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
280
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
281
+ """
282
+ with open(pose_file_path, 'r') as f:
283
+ poses = f.readlines()
284
+
285
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
286
+ cam_params = [[float(x) for x in pose] for pose in poses]
287
+ if return_poses:
288
+ return cam_params
289
+ else:
290
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
291
+
292
+ sample_wh_ratio = width / height
293
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
294
+
295
+ if pose_wh_ratio > sample_wh_ratio:
296
+ resized_ori_w = height * pose_wh_ratio
297
+ for cam_param in cam_params:
298
+ cam_param.fx = resized_ori_w * cam_param.fx / width
299
+ else:
300
+ resized_ori_h = width / pose_wh_ratio
301
+ for cam_param in cam_params:
302
+ cam_param.fy = resized_ori_h * cam_param.fy / height
303
+
304
+ intrinsic = np.asarray([[cam_param.fx * width,
305
+ cam_param.fy * height,
306
+ cam_param.cx * width,
307
+ cam_param.cy * height]
308
+ for cam_param in cam_params], dtype=np.float32)
309
+
310
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
311
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
312
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
313
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
314
+ plucker_embedding = plucker_embedding[None]
315
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
316
+ return plucker_embedding
317
+
318
+ def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
319
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
320
+ """
321
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
322
+
323
+ sample_wh_ratio = width / height
324
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
325
+
326
+ if pose_wh_ratio > sample_wh_ratio:
327
+ resized_ori_w = height * pose_wh_ratio
328
+ for cam_param in cam_params:
329
+ cam_param.fx = resized_ori_w * cam_param.fx / width
330
+ else:
331
+ resized_ori_h = width / pose_wh_ratio
332
+ for cam_param in cam_params:
333
+ cam_param.fy = resized_ori_h * cam_param.fy / height
334
+
335
+ intrinsic = np.asarray([[cam_param.fx * width,
336
+ cam_param.fy * height,
337
+ cam_param.cx * width,
338
+ cam_param.cy * height]
339
+ for cam_param in cam_params], dtype=np.float32)
340
+
341
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
342
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
343
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
344
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
345
+ plucker_embedding = plucker_embedding[None]
346
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
347
+ return plucker_embedding
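A minimal usage sketch of the helpers above (values chosen only for illustration; the `videox_fun.data.utils` import path is assumed from the file location):

from PIL import Image
from videox_fun.data.utils import get_random_mask, resize_image_with_target_area

# Random inpaint mask for a 16-frame, 3-channel, 384x640 clip, returned as (f, 1, h, w) uint8
mask = get_random_mask((16, 3, 384, 640))
print(mask.shape)  # torch.Size([16, 1, 384, 640])

# Resize a PIL image to roughly one megapixel, with width/height snapped to multiples of 32
img = Image.new("RGB", (1920, 1080))
print(resize_image_with_target_area(img, target_area=1024 * 1024).size)  # (1376, 768)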
videox_fun/pipeline/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ # from .pipeline_cogvideox_fun import CogVideoXFunPipeline
2
+ # from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
3
+ # from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
4
+ # from .pipeline_fantasy_talking import FantasyTalkingPipeline
5
+ # from .pipeline_flux import FluxPipeline
6
+ # from .pipeline_flux2 import Flux2Pipeline
7
+ # from .pipeline_flux2_control import Flux2ControlPipeline
8
+ # from .pipeline_hunyuanvideo import HunyuanVideoPipeline
9
+ # from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline
10
+ # from .pipeline_qwenimage import QwenImagePipeline
11
+ # from .pipeline_qwenimage_edit import QwenImageEditPipeline
12
+ # from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
13
+ # from .pipeline_wan import WanPipeline
14
+ # from .pipeline_wan2_2 import Wan2_2Pipeline
15
+ # from .pipeline_wan2_2_animate import Wan2_2AnimatePipeline
16
+ # from .pipeline_wan2_2_fun_control import Wan2_2FunControlPipeline
17
+ # from .pipeline_wan2_2_fun_inpaint import Wan2_2FunInpaintPipeline
18
+ # from .pipeline_wan2_2_s2v import Wan2_2S2VPipeline
19
+ # from .pipeline_wan2_2_ti2v import Wan2_2TI2VPipeline
20
+ # from .pipeline_wan2_2_vace_fun import Wan2_2VaceFunPipeline
21
+ # from .pipeline_wan_fun_control import WanFunControlPipeline
22
+ # from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline
23
+ # from .pipeline_wan_phantom import WanFunPhantomPipeline
24
+ # from .pipeline_wan_vace import WanVacePipeline
25
+ from .pipeline_z_image import ZImagePipeline
26
+ from .pipeline_z_image_control import ZImageControlPipeline
27
+
28
+ # WanFunPipeline = WanPipeline
29
+ # WanI2VPipeline = WanFunInpaintPipeline
30
+
31
+ # Wan2_2FunPipeline = Wan2_2Pipeline
32
+ # Wan2_2I2VPipeline = Wan2_2FunInpaintPipeline
33
+
34
+ # import importlib.util
35
+
36
+ # if importlib.util.find_spec("paifuser") is not None:
37
+ # # --------------------------------------------------------------- #
38
+ # # Sparse Attention
39
+ # # --------------------------------------------------------------- #
40
+ # from paifuser.ops import sparse_reset
41
+
42
+ # # Wan2.1
43
+ # WanFunInpaintPipeline.__call__ = sparse_reset(WanFunInpaintPipeline.__call__)
44
+ # WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__)
45
+ # WanFunControlPipeline.__call__ = sparse_reset(WanFunControlPipeline.__call__)
46
+ # WanI2VPipeline.__call__ = sparse_reset(WanI2VPipeline.__call__)
47
+ # WanPipeline.__call__ = sparse_reset(WanPipeline.__call__)
48
+ # WanVacePipeline.__call__ = sparse_reset(WanVacePipeline.__call__)
49
+
50
+ # # Phantom
51
+ # WanFunPhantomPipeline.__call__ = sparse_reset(WanFunPhantomPipeline.__call__)
52
+
53
+ # # Wan2.2
54
+ # Wan2_2FunInpaintPipeline.__call__ = sparse_reset(Wan2_2FunInpaintPipeline.__call__)
55
+ # Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__)
56
+ # Wan2_2FunControlPipeline.__call__ = sparse_reset(Wan2_2FunControlPipeline.__call__)
57
+ # Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__)
58
+ # Wan2_2I2VPipeline.__call__ = sparse_reset(Wan2_2I2VPipeline.__call__)
59
+ # Wan2_2TI2VPipeline.__call__ = sparse_reset(Wan2_2TI2VPipeline.__call__)
60
+ # Wan2_2S2VPipeline.__call__ = sparse_reset(Wan2_2S2VPipeline.__call__)
61
+ # Wan2_2VaceFunPipeline.__call__ = sparse_reset(Wan2_2VaceFunPipeline.__call__)
62
+ # Wan2_2AnimatePipeline.__call__ = sparse_reset(Wan2_2AnimatePipeline.__call__)
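With every other pipeline commented out above, only the two Z-Image pipelines are importable from the package; a quick sketch of what this `__init__` currently exposes:

from videox_fun.pipeline import ZImageControlPipeline, ZImagePipeline

# Only these two names are exported here; the remaining pipelines stay commented
# out (and the paifuser sparse-attention hooks disabled) until they are needed.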
videox_fun/pipeline/pipeline_cogvideox_fun.py ADDED
@@ -0,0 +1,862 @@
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
27
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.video_processor import VideoProcessor
30
+
31
+ from ..models import (AutoencoderKLCogVideoX,
32
+ CogVideoXTransformer3DModel, T5EncoderModel,
33
+ T5Tokenizer)
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```python
41
+ pass
42
+ ```
43
+ """
44
+
45
+
46
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
47
+ def get_3d_rotary_pos_embed(
48
+ embed_dim,
49
+ crops_coords,
50
+ grid_size,
51
+ temporal_size,
52
+ theta: int = 10000,
53
+ use_real: bool = True,
54
+ grid_type: str = "linspace",
55
+ max_size: Optional[Tuple[int, int]] = None,
56
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
57
+ """
58
+ RoPE for video tokens with 3D structure.
59
+
60
+ Args:
61
+ embed_dim: (`int`):
62
+ The embedding dimension size, corresponding to hidden_size_head.
63
+ crops_coords (`Tuple[int]`):
64
+ The top-left and bottom-right coordinates of the crop.
65
+ grid_size (`Tuple[int]`):
66
+ The grid size of the spatial positional embedding (height, width).
67
+ temporal_size (`int`):
68
+ The size of the temporal dimension.
69
+ theta (`float`):
70
+ Scaling factor for frequency computation.
71
+ grid_type (`str`):
72
+ Whether to use "linspace" or "slice" to compute grids.
73
+
74
+ Returns:
75
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
76
+ """
77
+ if use_real is not True:
78
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
79
+
80
+ if grid_type == "linspace":
81
+ start, stop = crops_coords
82
+ grid_size_h, grid_size_w = grid_size
83
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
84
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
85
+ grid_t = np.arange(temporal_size, dtype=np.float32)
86
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
87
+ elif grid_type == "slice":
88
+ max_h, max_w = max_size
89
+ grid_size_h, grid_size_w = grid_size
90
+ grid_h = np.arange(max_h, dtype=np.float32)
91
+ grid_w = np.arange(max_w, dtype=np.float32)
92
+ grid_t = np.arange(temporal_size, dtype=np.float32)
93
+ else:
94
+ raise ValueError("Invalid value passed for `grid_type`.")
95
+
96
+ # Compute dimensions for each axis
97
+ dim_t = embed_dim // 4
98
+ dim_h = embed_dim // 8 * 3
99
+ dim_w = embed_dim // 8 * 3
100
+
101
+ # Temporal frequencies
102
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
103
+ # Spatial frequencies for height and width
104
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
105
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
106
+
107
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
108
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
109
+ freqs_t = freqs_t[:, None, None, :].expand(
110
+ -1, grid_size_h, grid_size_w, -1
111
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
112
+ freqs_h = freqs_h[None, :, None, :].expand(
113
+ temporal_size, -1, grid_size_w, -1
114
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
115
+ freqs_w = freqs_w[None, None, :, :].expand(
116
+ temporal_size, grid_size_h, -1, -1
117
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
118
+
119
+ freqs = torch.cat(
120
+ [freqs_t, freqs_h, freqs_w], dim=-1
121
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
122
+ freqs = freqs.view(
123
+ temporal_size * grid_size_h * grid_size_w, -1
124
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
125
+ return freqs
126
+
127
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
128
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
129
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
130
+
131
+ if grid_type == "slice":
132
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
133
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
134
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
135
+
136
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
137
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
138
+ return cos, sin
139
+
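A quick shape check for `get_3d_rotary_pos_embed` (the head dimension and grid sizes below are chosen only for illustration):

# head dim 64 -> dim_t = 16, dim_h = dim_w = 24, which sum back to 64.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),  # top-left / bottom-right of the latent grid
    grid_size=(30, 45),               # latent height x width (in patches)
    temporal_size=13,                 # latent frames
)
print(cos.shape)  # torch.Size([17550, 64]) == (13 * 30 * 45, 64)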
140
+
141
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
142
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
143
+ tw = tgt_width
144
+ th = tgt_height
145
+ h, w = src
146
+ r = h / w
147
+ if r > (th / tw):
148
+ resize_height = th
149
+ resize_width = int(round(th / h * w))
150
+ else:
151
+ resize_width = tw
152
+ resize_height = int(round(tw / w * h))
153
+
154
+ crop_top = int(round((th - resize_height) / 2.0))
155
+ crop_left = int(round((tw - resize_width) / 2.0))
156
+
157
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
158
+
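A quick sanity check of the centered-crop math in `get_resize_crop_region_for_grid` (numbers chosen only for illustration):

# src is (height, width). 720/1280 = 0.5625 < 512/768, so the width is matched first:
# resize_width = 768, resize_height = round(768 / 1280 * 720) = 432,
# and the region is centered vertically: crop_top = round((512 - 432) / 2) = 40.
top_left, bottom_right = get_resize_crop_region_for_grid((720, 1280), tgt_width=768, tgt_height=512)
print(top_left, bottom_right)  # (40, 0) (472, 768)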
159
+
160
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
161
+ def retrieve_timesteps(
162
+ scheduler,
163
+ num_inference_steps: Optional[int] = None,
164
+ device: Optional[Union[str, torch.device]] = None,
165
+ timesteps: Optional[List[int]] = None,
166
+ sigmas: Optional[List[float]] = None,
167
+ **kwargs,
168
+ ):
169
+ """
170
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
171
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
172
+
173
+ Args:
174
+ scheduler (`SchedulerMixin`):
175
+ The scheduler to get timesteps from.
176
+ num_inference_steps (`int`):
177
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
178
+ must be `None`.
179
+ device (`str` or `torch.device`, *optional*):
180
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
181
+ timesteps (`List[int]`, *optional*):
182
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
183
+ `num_inference_steps` and `sigmas` must be `None`.
184
+ sigmas (`List[float]`, *optional*):
185
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
186
+ `num_inference_steps` and `timesteps` must be `None`.
187
+
188
+ Returns:
189
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
190
+ second element is the number of inference steps.
191
+ """
192
+ if timesteps is not None and sigmas is not None:
193
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
194
+ if timesteps is not None:
195
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
196
+ if not accepts_timesteps:
197
+ raise ValueError(
198
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
199
+ f" timestep schedules. Please check whether you are using the correct scheduler."
200
+ )
201
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
202
+ timesteps = scheduler.timesteps
203
+ num_inference_steps = len(timesteps)
204
+ elif sigmas is not None:
205
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
206
+ if not accept_sigmas:
207
+ raise ValueError(
208
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
209
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
210
+ )
211
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
212
+ timesteps = scheduler.timesteps
213
+ num_inference_steps = len(timesteps)
214
+ else:
215
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
216
+ timesteps = scheduler.timesteps
217
+ return timesteps, num_inference_steps
218
+
219
+
220
+ @dataclass
221
+ class CogVideoXFunPipelineOutput(BaseOutput):
222
+ r"""
223
+ Output class for CogVideo pipelines.
224
+
225
+ Args:
226
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
227
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
228
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
229
+ `(batch_size, num_frames, channels, height, width)`.
230
+ """
231
+
232
+ videos: torch.Tensor
233
+
234
+
235
+ class CogVideoXFunPipeline(DiffusionPipeline):
236
+ r"""
237
+ Pipeline for text-to-video generation using CogVideoX_Fun.
238
+
239
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
240
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
241
+
242
+ Args:
243
+ vae ([`AutoencoderKL`]):
244
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
245
+ text_encoder ([`T5EncoderModel`]):
246
+ Frozen text-encoder. CogVideoX uses
247
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
248
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
249
+ tokenizer (`T5Tokenizer`):
250
+ Tokenizer of class
251
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
252
+ transformer ([`CogVideoXTransformer3DModel`]):
253
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
254
+ scheduler ([`SchedulerMixin`]):
255
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
256
+ """
257
+
258
+ _optional_components = []
259
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
260
+
261
+ _callback_tensor_inputs = [
262
+ "latents",
263
+ "prompt_embeds",
264
+ "negative_prompt_embeds",
265
+ ]
266
+
267
+ def __init__(
268
+ self,
269
+ tokenizer: T5Tokenizer,
270
+ text_encoder: T5EncoderModel,
271
+ vae: AutoencoderKLCogVideoX,
272
+ transformer: CogVideoXTransformer3DModel,
273
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
274
+ ):
275
+ super().__init__()
276
+
277
+ self.register_modules(
278
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
279
+ )
280
+ self.vae_scale_factor_spatial = (
281
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
282
+ )
283
+ self.vae_scale_factor_temporal = (
284
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
285
+ )
286
+
287
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
288
+
289
+ def _get_t5_prompt_embeds(
290
+ self,
291
+ prompt: Union[str, List[str]] = None,
292
+ num_videos_per_prompt: int = 1,
293
+ max_sequence_length: int = 226,
294
+ device: Optional[torch.device] = None,
295
+ dtype: Optional[torch.dtype] = None,
296
+ ):
297
+ device = device or self._execution_device
298
+ dtype = dtype or self.text_encoder.dtype
299
+
300
+ prompt = [prompt] if isinstance(prompt, str) else prompt
301
+ batch_size = len(prompt)
302
+
303
+ text_inputs = self.tokenizer(
304
+ prompt,
305
+ padding="max_length",
306
+ max_length=max_sequence_length,
307
+ truncation=True,
308
+ add_special_tokens=True,
309
+ return_tensors="pt",
310
+ )
311
+ text_input_ids = text_inputs.input_ids
312
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
313
+
314
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
315
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
316
+ logger.warning(
317
+ "The following part of your input was truncated because `max_sequence_length` is set to "
318
+ f" {max_sequence_length} tokens: {removed_text}"
319
+ )
320
+
321
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
322
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
323
+
324
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
325
+ _, seq_len, _ = prompt_embeds.shape
326
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
327
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
328
+
329
+ return prompt_embeds
330
+
331
+ def encode_prompt(
332
+ self,
333
+ prompt: Union[str, List[str]],
334
+ negative_prompt: Optional[Union[str, List[str]]] = None,
335
+ do_classifier_free_guidance: bool = True,
336
+ num_videos_per_prompt: int = 1,
337
+ prompt_embeds: Optional[torch.Tensor] = None,
338
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
339
+ max_sequence_length: int = 226,
340
+ device: Optional[torch.device] = None,
341
+ dtype: Optional[torch.dtype] = None,
342
+ ):
343
+ r"""
344
+ Encodes the prompt into text encoder hidden states.
345
+
346
+ Args:
347
+ prompt (`str` or `List[str]`, *optional*):
348
+ prompt to be encoded
349
+ negative_prompt (`str` or `List[str]`, *optional*):
350
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
351
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
352
+ less than `1`).
353
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
354
+ Whether to use classifier free guidance or not.
355
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
356
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
357
+ prompt_embeds (`torch.Tensor`, *optional*):
358
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
359
+ provided, text embeddings will be generated from `prompt` input argument.
360
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
361
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
362
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
363
+ argument.
364
+ device: (`torch.device`, *optional*):
365
+ torch device
366
+ dtype: (`torch.dtype`, *optional*):
367
+ torch dtype
368
+ """
369
+ device = device or self._execution_device
370
+
371
+ prompt = [prompt] if isinstance(prompt, str) else prompt
372
+ if prompt is not None:
373
+ batch_size = len(prompt)
374
+ else:
375
+ batch_size = prompt_embeds.shape[0]
376
+
377
+ if prompt_embeds is None:
378
+ prompt_embeds = self._get_t5_prompt_embeds(
379
+ prompt=prompt,
380
+ num_videos_per_prompt=num_videos_per_prompt,
381
+ max_sequence_length=max_sequence_length,
382
+ device=device,
383
+ dtype=dtype,
384
+ )
385
+
386
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
387
+ negative_prompt = negative_prompt or ""
388
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
389
+
390
+ if prompt is not None and type(prompt) is not type(negative_prompt):
391
+ raise TypeError(
392
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
393
+ f" {type(prompt)}."
394
+ )
395
+ elif batch_size != len(negative_prompt):
396
+ raise ValueError(
397
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
398
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
399
+ " the batch size of `prompt`."
400
+ )
401
+
402
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
403
+ prompt=negative_prompt,
404
+ num_videos_per_prompt=num_videos_per_prompt,
405
+ max_sequence_length=max_sequence_length,
406
+ device=device,
407
+ dtype=dtype,
408
+ )
409
+
410
+ return prompt_embeds, negative_prompt_embeds
411
+
412
+ def prepare_latents(
413
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
414
+ ):
415
+ if isinstance(generator, list) and len(generator) != batch_size:
416
+ raise ValueError(
417
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
418
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
419
+ )
420
+
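+ # latent layout is [batch, latent_frames, channels, height / spatial_factor, width / spatial_factor];
+ # e.g. 49 input frames with a temporal compression of 4 give (49 - 1) // 4 + 1 = 13 latent frames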
421
+ shape = (
422
+ batch_size,
423
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
424
+ num_channels_latents,
425
+ height // self.vae_scale_factor_spatial,
426
+ width // self.vae_scale_factor_spatial,
427
+ )
428
+
429
+ if latents is None:
430
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
431
+ else:
432
+ latents = latents.to(device)
433
+
434
+ # scale the initial noise by the standard deviation required by the scheduler
435
+ latents = latents * self.scheduler.init_noise_sigma
436
+ return latents
437
+
438
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
439
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
440
+ latents = 1 / self.vae.config.scaling_factor * latents
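+ # undo the VAE scaling factor before decoding; decoded frames come back in [-1, 1] and are mapped to [0, 1] below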
441
+
442
+ frames = self.vae.decode(latents).sample
443
+ frames = (frames / 2 + 0.5).clamp(0, 1)
444
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
445
+ frames = frames.cpu().float().numpy()
446
+ return frames
447
+
448
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
449
+ def prepare_extra_step_kwargs(self, generator, eta):
450
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
451
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
452
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
453
+ # and should be between [0, 1]
454
+
455
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
456
+ extra_step_kwargs = {}
457
+ if accepts_eta:
458
+ extra_step_kwargs["eta"] = eta
459
+
460
+ # check if the scheduler accepts generator
461
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
462
+ if accepts_generator:
463
+ extra_step_kwargs["generator"] = generator
464
+ return extra_step_kwargs
465
+
466
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
467
+ def check_inputs(
468
+ self,
469
+ prompt,
470
+ height,
471
+ width,
472
+ negative_prompt,
473
+ callback_on_step_end_tensor_inputs,
474
+ prompt_embeds=None,
475
+ negative_prompt_embeds=None,
476
+ ):
477
+ if height % 8 != 0 or width % 8 != 0:
478
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
479
+
480
+ if callback_on_step_end_tensor_inputs is not None and not all(
481
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
482
+ ):
483
+ raise ValueError(
484
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
485
+ )
486
+ if prompt is not None and prompt_embeds is not None:
487
+ raise ValueError(
488
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
489
+ " only forward one of the two."
490
+ )
491
+ elif prompt is None and prompt_embeds is None:
492
+ raise ValueError(
493
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
494
+ )
495
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
496
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
497
+
498
+ if prompt is not None and negative_prompt_embeds is not None:
499
+ raise ValueError(
500
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
501
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
502
+ )
503
+
504
+ if negative_prompt is not None and negative_prompt_embeds is not None:
505
+ raise ValueError(
506
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
507
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
508
+ )
509
+
510
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
511
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
512
+ raise ValueError(
513
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
514
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
515
+ f" {negative_prompt_embeds.shape}."
516
+ )
517
+
518
+ def fuse_qkv_projections(self) -> None:
519
+ r"""Enables fused QKV projections."""
520
+ self.fusing_transformer = True
521
+ self.transformer.fuse_qkv_projections()
522
+
523
+ def unfuse_qkv_projections(self) -> None:
524
+ r"""Disable QKV projection fusion if enabled."""
525
+ if not self.fusing_transformer:
526
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
527
+ else:
528
+ self.transformer.unfuse_qkv_projections()
529
+ self.fusing_transformer = False
530
+
531
+ def _prepare_rotary_positional_embeddings(
532
+ self,
533
+ height: int,
534
+ width: int,
535
+ num_frames: int,
536
+ device: torch.device,
537
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
538
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
539
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
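+ # e.g. 480x720 pixels with a spatial VAE factor of 8 and patch size 2 give a 30x45 token grid (illustrative values; the actual config may differ)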
540
+
541
+ p = self.transformer.config.patch_size
542
+ p_t = self.transformer.config.patch_size_t
543
+
544
+ base_size_width = self.transformer.config.sample_width // p
545
+ base_size_height = self.transformer.config.sample_height // p
546
+
547
+ if p_t is None:
548
+ # CogVideoX 1.0
549
+ grid_crops_coords = get_resize_crop_region_for_grid(
550
+ (grid_height, grid_width), base_size_width, base_size_height
551
+ )
552
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
553
+ embed_dim=self.transformer.config.attention_head_dim,
554
+ crops_coords=grid_crops_coords,
555
+ grid_size=(grid_height, grid_width),
556
+ temporal_size=num_frames,
557
+ )
558
+ else:
559
+ # CogVideoX 1.5
560
+ base_num_frames = (num_frames + p_t - 1) // p_t
561
+
562
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
563
+ embed_dim=self.transformer.config.attention_head_dim,
564
+ crops_coords=None,
565
+ grid_size=(grid_height, grid_width),
566
+ temporal_size=base_num_frames,
567
+ grid_type="slice",
568
+ max_size=(base_size_height, base_size_width),
569
+ )
570
+
571
+ freqs_cos = freqs_cos.to(device=device)
572
+ freqs_sin = freqs_sin.to(device=device)
573
+ return freqs_cos, freqs_sin
574
+
575
+ @property
576
+ def guidance_scale(self):
577
+ return self._guidance_scale
578
+
579
+ @property
580
+ def num_timesteps(self):
581
+ return self._num_timesteps
582
+
583
+ @property
584
+ def attention_kwargs(self):
585
+ return self._attention_kwargs
586
+
587
+ @property
588
+ def interrupt(self):
589
+ return self._interrupt
590
+
591
+ @torch.no_grad()
592
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
593
+ def __call__(
594
+ self,
595
+ prompt: Optional[Union[str, List[str]]] = None,
596
+ negative_prompt: Optional[Union[str, List[str]]] = None,
597
+ height: int = 480,
598
+ width: int = 720,
599
+ num_frames: int = 49,
600
+ num_inference_steps: int = 50,
601
+ timesteps: Optional[List[int]] = None,
602
+ guidance_scale: float = 6,
603
+ use_dynamic_cfg: bool = False,
604
+ num_videos_per_prompt: int = 1,
605
+ eta: float = 0.0,
606
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
607
+ latents: Optional[torch.FloatTensor] = None,
608
+ prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
610
+ output_type: str = "numpy",
611
+ return_dict: bool = False,
612
+ callback_on_step_end: Optional[
613
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
614
+ ] = None,
615
+ attention_kwargs: Optional[Dict[str, Any]] = None,
616
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
617
+ max_sequence_length: int = 226,
618
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
619
+ """
620
+ Function invoked when calling the pipeline for generation.
621
+
622
+ Args:
623
+ prompt (`str` or `List[str]`, *optional*):
624
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
625
+ instead.
626
+ negative_prompt (`str` or `List[str]`, *optional*):
627
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
628
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
629
+ less than `1`).
630
+ height (`int`, *optional*, defaults to `480`):
631
+ The height in pixels of the generated video.
632
+ width (`int`, *optional*, defaults to `720`):
633
+ The width in pixels of the generated video.
634
+ num_frames (`int`, defaults to `49`):
635
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
636
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
637
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
638
+ needs to be satisfied is that of divisibility mentioned above.
639
+ num_inference_steps (`int`, *optional*, defaults to 50):
640
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
641
+ expense of slower inference.
642
+ timesteps (`List[int]`, *optional*):
643
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
644
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
645
+ passed will be used. Must be in descending order.
646
+ guidance_scale (`float`, *optional*, defaults to `6`):
647
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
648
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
649
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
650
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
651
+ usually at the expense of lower image quality.
652
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
653
+ The number of videos to generate per prompt.
654
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
655
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
656
+ to make generation deterministic.
657
+ latents (`torch.FloatTensor`, *optional*):
658
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
659
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
660
+ tensor will be generated by sampling using the supplied random `generator`.
661
+ prompt_embeds (`torch.FloatTensor`, *optional*):
662
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
663
+ provided, text embeddings will be generated from `prompt` input argument.
664
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
665
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
666
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
667
+ argument.
668
+ output_type (`str`, *optional*, defaults to `"numpy"`):
669
+ The output format of the generated video. Choose between
670
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
671
+ return_dict (`bool`, *optional*, defaults to `False`):
672
+ Whether or not to return a [`CogVideoXFunPipelineOutput`] instead
673
+ of a plain tuple.
674
+ callback_on_step_end (`Callable`, *optional*):
675
+ A function that calls at the end of each denoising steps during the inference. The function is called
676
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
677
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
678
+ `callback_on_step_end_tensor_inputs`.
679
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
680
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
681
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
682
+ `._callback_tensor_inputs` attribute of your pipeline class.
683
+ max_sequence_length (`int`, defaults to `226`):
684
+ Maximum sequence length in encoded prompt. Must be consistent with
685
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
686
+
687
+ Examples:
688
+
689
+ Returns:
690
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
691
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
692
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
693
+ """
694
+
695
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
696
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
697
+
698
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
699
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
700
+ num_frames = num_frames or self.transformer.config.sample_frames
701
+
702
+ num_videos_per_prompt = 1
703
+
704
+ # 1. Check inputs. Raise error if not correct
705
+ self.check_inputs(
706
+ prompt,
707
+ height,
708
+ width,
709
+ negative_prompt,
710
+ callback_on_step_end_tensor_inputs,
711
+ prompt_embeds,
712
+ negative_prompt_embeds,
713
+ )
714
+ self._guidance_scale = guidance_scale
715
+ self._attention_kwargs = attention_kwargs
716
+ self._interrupt = False
717
+
718
+ # 2. Default call parameters
719
+ if prompt is not None and isinstance(prompt, str):
720
+ batch_size = 1
721
+ elif prompt is not None and isinstance(prompt, list):
722
+ batch_size = len(prompt)
723
+ else:
724
+ batch_size = prompt_embeds.shape[0]
725
+
726
+ device = self._execution_device
727
+
728
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
729
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
730
+ # corresponds to doing no classifier free guidance.
731
+ do_classifier_free_guidance = guidance_scale > 1.0
732
+
733
+ # 3. Encode input prompt
734
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
735
+ prompt,
736
+ negative_prompt,
737
+ do_classifier_free_guidance,
738
+ num_videos_per_prompt=num_videos_per_prompt,
739
+ prompt_embeds=prompt_embeds,
740
+ negative_prompt_embeds=negative_prompt_embeds,
741
+ max_sequence_length=max_sequence_length,
742
+ device=device,
743
+ )
744
+ if do_classifier_free_guidance:
745
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
746
+
747
+ # 4. Prepare timesteps
748
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
749
+ self._num_timesteps = len(timesteps)
750
+
751
+ # 5. Prepare latents
752
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
753
+
754
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
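+ # e.g. 49 frames with temporal compression 4 give 13 latent frames; if patch_size_t were 2, one latent frame is padded and num_frames becomes 53 (illustrative values)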
755
+ patch_size_t = self.transformer.config.patch_size_t
756
+ additional_frames = 0
757
+ if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0:
758
+ additional_frames = patch_size_t - latent_frames % patch_size_t
759
+ num_frames += additional_frames * self.vae_scale_factor_temporal
760
+
761
+ latent_channels = self.transformer.config.in_channels
762
+ latents = self.prepare_latents(
763
+ batch_size * num_videos_per_prompt,
764
+ latent_channels,
765
+ num_frames,
766
+ height,
767
+ width,
768
+ prompt_embeds.dtype,
769
+ device,
770
+ generator,
771
+ latents,
772
+ )
773
+
774
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
775
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
776
+
777
+ # 7. Create rotary embeds if required
778
+ image_rotary_emb = (
779
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
780
+ if self.transformer.config.use_rotary_positional_embeddings
781
+ else None
782
+ )
783
+
784
+ # 8. Denoising loop
785
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
786
+
787
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
788
+ # for DPM-solver++
789
+ old_pred_original_sample = None
790
+ for i, t in enumerate(timesteps):
791
+ if self.interrupt:
792
+ continue
793
+
794
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
795
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
796
+
797
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
798
+ timestep = t.expand(latent_model_input.shape[0])
799
+
800
+ # predict noise model_output
801
+ noise_pred = self.transformer(
802
+ hidden_states=latent_model_input,
803
+ encoder_hidden_states=prompt_embeds,
804
+ timestep=timestep,
805
+ image_rotary_emb=image_rotary_emb,
806
+ return_dict=False,
807
+ )[0]
808
+ noise_pred = noise_pred.float()
809
+
810
+ # perform guidance
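+ # when use_dynamic_cfg is set, the effective guidance scale is rescheduled every step with a cosine-based curve instead of staying constant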
811
+ if use_dynamic_cfg:
812
+ self._guidance_scale = 1 + guidance_scale * (
813
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
814
+ )
815
+ if do_classifier_free_guidance:
816
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
817
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
818
+
819
+ # compute the previous noisy sample x_t -> x_t-1
820
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
821
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
822
+ else:
823
+ latents, old_pred_original_sample = self.scheduler.step(
824
+ noise_pred,
825
+ old_pred_original_sample,
826
+ t,
827
+ timesteps[i - 1] if i > 0 else None,
828
+ latents,
829
+ **extra_step_kwargs,
830
+ return_dict=False,
831
+ )
832
+ latents = latents.to(prompt_embeds.dtype)
833
+
834
+ # call the callback, if provided
835
+ if callback_on_step_end is not None:
836
+ callback_kwargs = {}
837
+ for k in callback_on_step_end_tensor_inputs:
838
+ callback_kwargs[k] = locals()[k]
839
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
840
+
841
+ latents = callback_outputs.pop("latents", latents)
842
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
843
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
844
+
845
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
846
+ progress_bar.update()
847
+
848
+ if output_type == "numpy":
849
+ video = self.decode_latents(latents)
850
+ elif not output_type == "latent":
851
+ video = self.decode_latents(latents)
852
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
853
+ else:
854
+ video = latents
855
+
856
+ # Offload all models
857
+ self.maybe_free_model_hooks()
858
+
859
+ if not return_dict:
860
+ video = torch.from_numpy(video)
861
+
862
+ return CogVideoXFunPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_cogvideox_fun_control.py ADDED
@@ -0,0 +1,956 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import (get_1d_rotary_pos_embed,
27
+ get_3d_rotary_pos_embed)
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
30
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.video_processor import VideoProcessor
33
+ from einops import rearrange
34
+
35
+ from ..models import (AutoencoderKLCogVideoX,
36
+ CogVideoXTransformer3DModel, T5EncoderModel,
37
+ T5Tokenizer)
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```python
45
+ pass
46
+ ```
47
+ """
48
+
49
+
50
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
51
+ def get_3d_rotary_pos_embed(
52
+ embed_dim,
53
+ crops_coords,
54
+ grid_size,
55
+ temporal_size,
56
+ theta: int = 10000,
57
+ use_real: bool = True,
58
+ grid_type: str = "linspace",
59
+ max_size: Optional[Tuple[int, int]] = None,
60
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
61
+ """
62
+ RoPE for video tokens with 3D structure.
63
+
64
+ Args:
65
+ embed_dim: (`int`):
66
+ The embedding dimension size, corresponding to hidden_size_head.
67
+ crops_coords (`Tuple[int]`):
68
+ The top-left and bottom-right coordinates of the crop.
69
+ grid_size (`Tuple[int]`):
70
+ The grid size of the spatial positional embedding (height, width).
71
+ temporal_size (`int`):
72
+ The size of the temporal dimension.
73
+ theta (`float`):
74
+ Scaling factor for frequency computation.
75
+ grid_type (`str`):
76
+ Whether to use "linspace" or "slice" to compute grids.
77
+
78
+ Returns:
79
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
80
+ """
81
+ if use_real is not True:
82
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
83
+
84
+ if grid_type == "linspace":
85
+ start, stop = crops_coords
86
+ grid_size_h, grid_size_w = grid_size
87
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
88
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
89
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
91
+ elif grid_type == "slice":
92
+ max_h, max_w = max_size
93
+ grid_size_h, grid_size_w = grid_size
94
+ grid_h = np.arange(max_h, dtype=np.float32)
95
+ grid_w = np.arange(max_w, dtype=np.float32)
96
+ grid_t = np.arange(temporal_size, dtype=np.float32)
97
+ else:
98
+ raise ValueError("Invalid value passed for `grid_type`.")
99
+
100
+ # Compute dimensions for each axis
101
+ dim_t = embed_dim // 4
102
+ dim_h = embed_dim // 8 * 3
103
+ dim_w = embed_dim // 8 * 3
104
+
105
+ # Temporal frequencies
106
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
107
+ # Spatial frequencies for height and width
108
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
109
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
110
+
111
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
112
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
113
+ freqs_t = freqs_t[:, None, None, :].expand(
114
+ -1, grid_size_h, grid_size_w, -1
115
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
116
+ freqs_h = freqs_h[None, :, None, :].expand(
117
+ temporal_size, -1, grid_size_w, -1
118
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
119
+ freqs_w = freqs_w[None, None, :, :].expand(
120
+ temporal_size, grid_size_h, -1, -1
121
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
122
+
123
+ freqs = torch.cat(
124
+ [freqs_t, freqs_h, freqs_w], dim=-1
125
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
126
+ freqs = freqs.view(
127
+ temporal_size * grid_size_h * grid_size_w, -1
128
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
129
+ return freqs
130
+
131
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
132
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
133
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
134
+
135
+ if grid_type == "slice":
136
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
137
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
138
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
139
+
140
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
141
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
142
+ return cos, sin
143
+
144
+
145
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
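+ # Fits the source aspect ratio inside the target grid and returns the centered (top-left, bottom-right) crop region used as RoPE crop coordinates.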
146
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
147
+ tw = tgt_width
148
+ th = tgt_height
149
+ h, w = src
150
+ r = h / w
151
+ if r > (th / tw):
152
+ resize_height = th
153
+ resize_width = int(round(th / h * w))
154
+ else:
155
+ resize_width = tw
156
+ resize_height = int(round(tw / w * h))
157
+
158
+ crop_top = int(round((th - resize_height) / 2.0))
159
+ crop_left = int(round((tw - resize_width) / 2.0))
160
+
161
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
162
+
163
+
164
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
165
+ def retrieve_timesteps(
166
+ scheduler,
167
+ num_inference_steps: Optional[int] = None,
168
+ device: Optional[Union[str, torch.device]] = None,
169
+ timesteps: Optional[List[int]] = None,
170
+ sigmas: Optional[List[float]] = None,
171
+ **kwargs,
172
+ ):
173
+ """
174
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
175
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
176
+
177
+ Args:
178
+ scheduler (`SchedulerMixin`):
179
+ The scheduler to get timesteps from.
180
+ num_inference_steps (`int`):
181
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
182
+ must be `None`.
183
+ device (`str` or `torch.device`, *optional*):
184
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
185
+ timesteps (`List[int]`, *optional*):
186
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
187
+ `num_inference_steps` and `sigmas` must be `None`.
188
+ sigmas (`List[float]`, *optional*):
189
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
190
+ `num_inference_steps` and `timesteps` must be `None`.
191
+
192
+ Returns:
193
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
194
+ second element is the number of inference steps.
195
+ """
196
+ if timesteps is not None and sigmas is not None:
197
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
198
+ if timesteps is not None:
199
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
200
+ if not accepts_timesteps:
201
+ raise ValueError(
202
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
203
+ f" timestep schedules. Please check whether you are using the correct scheduler."
204
+ )
205
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
206
+ timesteps = scheduler.timesteps
207
+ num_inference_steps = len(timesteps)
208
+ elif sigmas is not None:
209
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
210
+ if not accept_sigmas:
211
+ raise ValueError(
212
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
213
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
214
+ )
215
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
216
+ timesteps = scheduler.timesteps
217
+ num_inference_steps = len(timesteps)
218
+ else:
219
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
220
+ timesteps = scheduler.timesteps
221
+ return timesteps, num_inference_steps
222
+
223
+
224
+ @dataclass
225
+ class CogVideoXFunPipelineOutput(BaseOutput):
226
+ r"""
227
+ Output class for CogVideo pipelines.
228
+
229
+ Args:
230
+ videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
231
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
232
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
233
+ `(batch_size, num_frames, channels, height, width)`.
234
+ """
235
+
236
+ videos: torch.Tensor
237
+
238
+
239
+ class CogVideoXFunControlPipeline(DiffusionPipeline):
240
+ r"""
241
+ Pipeline for controllable text-to-video generation using CogVideoX-Fun.
242
+
243
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
244
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
245
+
246
+ Args:
247
+ vae ([`AutoencoderKL`]):
248
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
249
+ text_encoder ([`T5EncoderModel`]):
250
+ Frozen text-encoder. CogVideoX_Fun uses
251
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
252
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
253
+ tokenizer (`T5Tokenizer`):
254
+ Tokenizer of class
255
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
256
+ transformer ([`CogVideoXTransformer3DModel`]):
257
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
258
+ scheduler ([`SchedulerMixin`]):
259
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
260
+ """
261
+
262
+ _optional_components = []
263
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
264
+
265
+ _callback_tensor_inputs = [
266
+ "latents",
267
+ "prompt_embeds",
268
+ "negative_prompt_embeds",
269
+ ]
270
+
271
+ def __init__(
272
+ self,
273
+ tokenizer: T5Tokenizer,
274
+ text_encoder: T5EncoderModel,
275
+ vae: AutoencoderKLCogVideoX,
276
+ transformer: CogVideoXTransformer3DModel,
277
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
278
+ ):
279
+ super().__init__()
280
+
281
+ self.register_modules(
282
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
283
+ )
284
+ self.vae_scale_factor_spatial = (
285
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
286
+ )
287
+ self.vae_scale_factor_temporal = (
288
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
289
+ )
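+ # e.g. a VAE with 4 entries in block_out_channels gives a spatial factor of 2**3 = 8; CogVideoX's VAE typically compresses time by a factor of 4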
290
+
291
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
292
+
293
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
294
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
295
+ self.mask_processor = VaeImageProcessor(
296
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
297
+ )
298
+
299
+ def _get_t5_prompt_embeds(
300
+ self,
301
+ prompt: Union[str, List[str]] = None,
302
+ num_videos_per_prompt: int = 1,
303
+ max_sequence_length: int = 226,
304
+ device: Optional[torch.device] = None,
305
+ dtype: Optional[torch.dtype] = None,
306
+ ):
307
+ device = device or self._execution_device
308
+ dtype = dtype or self.text_encoder.dtype
309
+
310
+ prompt = [prompt] if isinstance(prompt, str) else prompt
311
+ batch_size = len(prompt)
312
+
313
+ text_inputs = self.tokenizer(
314
+ prompt,
315
+ padding="max_length",
316
+ max_length=max_sequence_length,
317
+ truncation=True,
318
+ add_special_tokens=True,
319
+ return_tensors="pt",
320
+ )
321
+ text_input_ids = text_inputs.input_ids
322
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
323
+
324
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
325
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
326
+ logger.warning(
327
+ "The following part of your input was truncated because `max_sequence_length` is set to "
328
+ f" {max_sequence_length} tokens: {removed_text}"
329
+ )
330
+
331
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
332
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
333
+
334
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
335
+ _, seq_len, _ = prompt_embeds.shape
336
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
338
+
339
+ return prompt_embeds
340
+
341
+ def encode_prompt(
342
+ self,
343
+ prompt: Union[str, List[str]],
344
+ negative_prompt: Optional[Union[str, List[str]]] = None,
345
+ do_classifier_free_guidance: bool = True,
346
+ num_videos_per_prompt: int = 1,
347
+ prompt_embeds: Optional[torch.Tensor] = None,
348
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
349
+ max_sequence_length: int = 226,
350
+ device: Optional[torch.device] = None,
351
+ dtype: Optional[torch.dtype] = None,
352
+ ):
353
+ r"""
354
+ Encodes the prompt into text encoder hidden states.
355
+
356
+ Args:
357
+ prompt (`str` or `List[str]`, *optional*):
358
+ prompt to be encoded
359
+ negative_prompt (`str` or `List[str]`, *optional*):
360
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
361
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
362
+ less than `1`).
363
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
364
+ Whether to use classifier free guidance or not.
365
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
366
+ Number of videos that should be generated per prompt.
367
+ prompt_embeds (`torch.Tensor`, *optional*):
368
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
369
+ provided, text embeddings will be generated from `prompt` input argument.
370
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
371
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
372
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
373
+ argument.
374
+ device: (`torch.device`, *optional*):
375
+ torch device
376
+ dtype: (`torch.dtype`, *optional*):
377
+ torch dtype
378
+ """
379
+ device = device or self._execution_device
380
+
381
+ prompt = [prompt] if isinstance(prompt, str) else prompt
382
+ if prompt is not None:
383
+ batch_size = len(prompt)
384
+ else:
385
+ batch_size = prompt_embeds.shape[0]
386
+
387
+ if prompt_embeds is None:
388
+ prompt_embeds = self._get_t5_prompt_embeds(
389
+ prompt=prompt,
390
+ num_videos_per_prompt=num_videos_per_prompt,
391
+ max_sequence_length=max_sequence_length,
392
+ device=device,
393
+ dtype=dtype,
394
+ )
395
+
396
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
397
+ negative_prompt = negative_prompt or ""
398
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
399
+
400
+ if prompt is not None and type(prompt) is not type(negative_prompt):
401
+ raise TypeError(
402
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
403
+ f" {type(prompt)}."
404
+ )
405
+ elif batch_size != len(negative_prompt):
406
+ raise ValueError(
407
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
408
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
409
+ " the batch size of `prompt`."
410
+ )
411
+
412
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
413
+ prompt=negative_prompt,
414
+ num_videos_per_prompt=num_videos_per_prompt,
415
+ max_sequence_length=max_sequence_length,
416
+ device=device,
417
+ dtype=dtype,
418
+ )
419
+
420
+ return prompt_embeds, negative_prompt_embeds
421
+
422
+ def prepare_latents(
423
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
424
+ ):
425
+ shape = (
426
+ batch_size,
427
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
428
+ num_channels_latents,
429
+ height // self.vae_scale_factor_spatial,
430
+ width // self.vae_scale_factor_spatial,
431
+ )
432
+ if isinstance(generator, list) and len(generator) != batch_size:
433
+ raise ValueError(
434
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
435
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
436
+ )
437
+
438
+ if latents is None:
439
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
440
+ else:
441
+ latents = latents.to(device)
442
+
443
+ # scale the initial noise by the standard deviation required by the scheduler
444
+ latents = latents * self.scheduler.init_noise_sigma
445
+ return latents
446
+
447
+ def prepare_control_latents(
448
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
449
+ ):
450
+ # resize the mask to latents shape as we concatenate the mask to the latents
451
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
452
+ # and half precision
453
+
454
+ if mask is not None:
455
+ mask = mask.to(device=device, dtype=self.vae.dtype)
456
+ bs = 1
457
+ new_mask = []
458
+ for i in range(0, mask.shape[0], bs):
459
+ mask_bs = mask[i : i + bs]
460
+ mask_bs = self.vae.encode(mask_bs)[0]
461
+ mask_bs = mask_bs.mode()
462
+ new_mask.append(mask_bs)
463
+ mask = torch.cat(new_mask, dim = 0)
464
+ mask = mask * self.vae.config.scaling_factor
465
+
466
+ if masked_image is not None:
467
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
468
+ bs = 1
469
+ new_mask_pixel_values = []
470
+ for i in range(0, masked_image.shape[0], bs):
471
+ mask_pixel_values_bs = masked_image[i : i + bs]
472
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
473
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
474
+ new_mask_pixel_values.append(mask_pixel_values_bs)
475
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
476
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
477
+ else:
478
+ masked_image_latents = None
479
+
480
+ return mask, masked_image_latents
481
+
482
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
483
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
484
+ latents = 1 / self.vae.config.scaling_factor * latents
485
+
486
+ frames = self.vae.decode(latents).sample
487
+ frames = (frames / 2 + 0.5).clamp(0, 1)
488
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
489
+ frames = frames.cpu().float().numpy()
490
+ return frames
491
+
492
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
493
+ def prepare_extra_step_kwargs(self, generator, eta):
494
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
495
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
496
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
497
+ # and should be between [0, 1]
498
+
499
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
500
+ extra_step_kwargs = {}
501
+ if accepts_eta:
502
+ extra_step_kwargs["eta"] = eta
503
+
504
+ # check if the scheduler accepts generator
505
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
506
+ if accepts_generator:
507
+ extra_step_kwargs["generator"] = generator
508
+ return extra_step_kwargs
509
+
510
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
511
+ def check_inputs(
512
+ self,
513
+ prompt,
514
+ height,
515
+ width,
516
+ negative_prompt,
517
+ callback_on_step_end_tensor_inputs,
518
+ prompt_embeds=None,
519
+ negative_prompt_embeds=None,
520
+ ):
521
+ if height % 8 != 0 or width % 8 != 0:
522
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
523
+
524
+ if callback_on_step_end_tensor_inputs is not None and not all(
525
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
526
+ ):
527
+ raise ValueError(
528
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
529
+ )
530
+ if prompt is not None and prompt_embeds is not None:
531
+ raise ValueError(
532
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
533
+ " only forward one of the two."
534
+ )
535
+ elif prompt is None and prompt_embeds is None:
536
+ raise ValueError(
537
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
538
+ )
539
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
540
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
541
+
542
+ if prompt is not None and negative_prompt_embeds is not None:
543
+ raise ValueError(
544
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
545
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
546
+ )
547
+
548
+ if negative_prompt is not None and negative_prompt_embeds is not None:
549
+ raise ValueError(
550
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
551
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
552
+ )
553
+
554
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
555
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
556
+ raise ValueError(
557
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
558
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
559
+ f" {negative_prompt_embeds.shape}."
560
+ )
561
+
562
+ def fuse_qkv_projections(self) -> None:
563
+ r"""Enables fused QKV projections."""
564
+ self.fusing_transformer = True
565
+ self.transformer.fuse_qkv_projections()
566
+
567
+ def unfuse_qkv_projections(self) -> None:
568
+ r"""Disable QKV projection fusion if enabled."""
569
+ if not self.fusing_transformer:
570
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
571
+ else:
572
+ self.transformer.unfuse_qkv_projections()
573
+ self.fusing_transformer = False
574
+
575
+ def _prepare_rotary_positional_embeddings(
576
+ self,
577
+ height: int,
578
+ width: int,
579
+ num_frames: int,
580
+ device: torch.device,
581
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
582
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
583
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
584
+
585
+ p = self.transformer.config.patch_size
586
+ p_t = self.transformer.config.patch_size_t
587
+
588
+ base_size_width = self.transformer.config.sample_width // p
589
+ base_size_height = self.transformer.config.sample_height // p
590
+
591
+ if p_t is None:
592
+ # CogVideoX 1.0
593
+ grid_crops_coords = get_resize_crop_region_for_grid(
594
+ (grid_height, grid_width), base_size_width, base_size_height
595
+ )
596
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
597
+ embed_dim=self.transformer.config.attention_head_dim,
598
+ crops_coords=grid_crops_coords,
599
+ grid_size=(grid_height, grid_width),
600
+ temporal_size=num_frames,
601
+ )
602
+ else:
603
+ # CogVideoX 1.5
604
+ base_num_frames = (num_frames + p_t - 1) // p_t
605
+
606
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
607
+ embed_dim=self.transformer.config.attention_head_dim,
608
+ crops_coords=None,
609
+ grid_size=(grid_height, grid_width),
610
+ temporal_size=base_num_frames,
611
+ grid_type="slice",
612
+ max_size=(base_size_height, base_size_width),
613
+ )
614
+
615
+ freqs_cos = freqs_cos.to(device=device)
616
+ freqs_sin = freqs_sin.to(device=device)
617
+ return freqs_cos, freqs_sin
618
+
619
+ @property
620
+ def guidance_scale(self):
621
+ return self._guidance_scale
622
+
623
+ @property
624
+ def num_timesteps(self):
625
+ return self._num_timesteps
626
+
627
+ @property
628
+ def interrupt(self):
629
+ return self._interrupt
630
+
631
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
632
+ def get_timesteps(self, num_inference_steps, strength, device):
633
+ # get the original timestep using init_timestep
634
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
635
+
636
+ t_start = max(num_inference_steps - init_timestep, 0)
637
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
638
+
639
+ return timesteps, num_inference_steps - t_start
640
+
641
+ @torch.no_grad()
642
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
643
+ def __call__(
644
+ self,
645
+ prompt: Optional[Union[str, List[str]]] = None,
646
+ negative_prompt: Optional[Union[str, List[str]]] = None,
647
+ height: int = 480,
648
+ width: int = 720,
649
+ video: Union[torch.FloatTensor] = None,
650
+ control_video: Union[torch.FloatTensor] = None,
651
+ num_frames: int = 49,
652
+ num_inference_steps: int = 50,
653
+ timesteps: Optional[List[int]] = None,
654
+ guidance_scale: float = 6,
655
+ use_dynamic_cfg: bool = False,
656
+ num_videos_per_prompt: int = 1,
657
+ eta: float = 0.0,
658
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
659
+ latents: Optional[torch.FloatTensor] = None,
660
+ prompt_embeds: Optional[torch.FloatTensor] = None,
661
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
662
+ output_type: str = "numpy",
663
+ return_dict: bool = False,
664
+ callback_on_step_end: Optional[
665
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
666
+ ] = None,
667
+ attention_kwargs: Optional[Dict[str, Any]] = None,
668
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
669
+ max_sequence_length: int = 226,
670
+ comfyui_progressbar: bool = False,
671
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
672
+ """
673
+ Function invoked when calling the pipeline for generation.
674
+
675
+ Args:
676
+ prompt (`str` or `List[str]`, *optional*):
677
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
678
+ instead.
679
+ negative_prompt (`str` or `List[str]`, *optional*):
680
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
681
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
682
+ less than `1`).
683
+ height (`int`, *optional*, defaults to `480`):
684
+ The height in pixels of the generated video.
685
+ width (`int`, *optional*, defaults to `720`):
686
+ The width in pixels of the generated video.
687
+ num_frames (`int`, defaults to `49`):
688
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
689
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
690
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
691
+ needs to be satisfied is that of divisibility mentioned above.
692
+ num_inference_steps (`int`, *optional*, defaults to 50):
693
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
694
+ expense of slower inference.
695
+ timesteps (`List[int]`, *optional*):
696
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
697
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
698
+ passed will be used. Must be in descending order.
699
+ guidance_scale (`float`, *optional*, defaults to `6`):
700
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
701
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
702
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
703
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
704
+ usually at the expense of lower image quality.
705
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
706
+ The number of videos to generate per prompt.
707
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
708
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
709
+ to make generation deterministic.
710
+ latents (`torch.FloatTensor`, *optional*):
711
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
712
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
713
+ tensor will be generated by sampling using the supplied random `generator`.
714
+ prompt_embeds (`torch.FloatTensor`, *optional*):
715
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
716
+ provided, text embeddings will be generated from `prompt` input argument.
717
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
718
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
719
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
720
+ argument.
721
+ output_type (`str`, *optional*, defaults to `"numpy"`):
722
+ The output format of the generated video. Choose between
723
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
724
+ return_dict (`bool`, *optional*, defaults to `False`):
725
+ Whether or not to return a [`CogVideoXFunPipelineOutput`] instead
726
+ of a plain tuple.
727
+ callback_on_step_end (`Callable`, *optional*):
728
+ A function that is called at the end of each denoising step during inference. The function is called
729
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
730
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
731
+ `callback_on_step_end_tensor_inputs`.
732
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
733
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
734
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
735
+ `._callback_tensor_inputs` attribute of your pipeline class.
736
+ max_sequence_length (`int`, defaults to `226`):
737
+ Maximum sequence length in encoded prompt. Must be consistent with
738
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
739
+
740
+ Examples:
741
+
742
+ Returns:
743
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
744
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
745
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
746
+ """
747
+
748
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
749
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
750
+
751
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
752
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
753
+ num_frames = num_frames or self.transformer.config.sample_frames
754
+
755
+ num_videos_per_prompt = 1
756
+
757
+ # 1. Check inputs. Raise error if not correct
758
+ self.check_inputs(
759
+ prompt,
760
+ height,
761
+ width,
762
+ negative_prompt,
763
+ callback_on_step_end_tensor_inputs,
764
+ prompt_embeds,
765
+ negative_prompt_embeds,
766
+ )
767
+ self._guidance_scale = guidance_scale
768
+ self._attention_kwargs = attention_kwargs
769
+ self._interrupt = False
770
+
771
+ # 2. Default call parameters
772
+ if prompt is not None and isinstance(prompt, str):
773
+ batch_size = 1
774
+ elif prompt is not None and isinstance(prompt, list):
775
+ batch_size = len(prompt)
776
+ else:
777
+ batch_size = prompt_embeds.shape[0]
778
+
779
+ device = self._execution_device
780
+
781
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
782
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
783
+ # corresponds to doing no classifier free guidance.
784
+ do_classifier_free_guidance = guidance_scale > 1.0
785
+
786
+ # 3. Encode input prompt
787
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
788
+ prompt,
789
+ negative_prompt,
790
+ do_classifier_free_guidance,
791
+ num_videos_per_prompt=num_videos_per_prompt,
792
+ prompt_embeds=prompt_embeds,
793
+ negative_prompt_embeds=negative_prompt_embeds,
794
+ max_sequence_length=max_sequence_length,
795
+ device=device,
796
+ )
797
+ if do_classifier_free_guidance:
798
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
799
+
800
+ # 4. Prepare timesteps
801
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
802
+ self._num_timesteps = len(timesteps)
803
+ if comfyui_progressbar:
804
+ from comfy.utils import ProgressBar
805
+ pbar = ProgressBar(num_inference_steps + 2)
806
+
807
+ if control_video is not None:
808
+ video_length = control_video.shape[2]
809
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
810
+ control_video = control_video.to(dtype=torch.float32)
811
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
812
+ else:
813
+ control_video = None
814
+
815
+ # Magvae needs the number of frames to be 4n + 1.
816
+ local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
817
+ # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
818
+ patch_size_t = self.transformer.config.patch_size_t
819
+ additional_frames = 0
820
+ if patch_size_t is not None and local_latent_length % patch_size_t != 0:
821
+ additional_frames = local_latent_length % patch_size_t
822
+ num_frames -= additional_frames * self.vae_scale_factor_temporal
823
+ if num_frames <= 0:
824
+ num_frames = 1
825
+ if control_video is not None and video_length > num_frames:
826
+ logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
827
+ video_length = num_frames
828
+ control_video = control_video[:, :, :video_length]
829
+
830
+ # 5. Prepare latents.
831
+ latent_channels = self.vae.config.latent_channels
832
+ latents = self.prepare_latents(
833
+ batch_size * num_videos_per_prompt,
834
+ latent_channels,
835
+ num_frames,
836
+ height,
837
+ width,
838
+ prompt_embeds.dtype,
839
+ device,
840
+ generator,
841
+ latents,
842
+ )
843
+ if comfyui_progressbar:
844
+ pbar.update(1)
845
+
846
+ control_video_latents = self.prepare_control_latents(
847
+ None,
848
+ control_video,
849
+ batch_size,
850
+ height,
851
+ width,
852
+ prompt_embeds.dtype,
853
+ device,
854
+ generator,
855
+ do_classifier_free_guidance
856
+ )[1]
857
+ control_video_latents_input = (
858
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
859
+ )
860
+ control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
861
+
862
+ if comfyui_progressbar:
863
+ pbar.update(1)
864
+
865
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
866
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
867
+
868
+ # 7. Create rotary embeds if required
869
+ image_rotary_emb = (
870
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
871
+ if self.transformer.config.use_rotary_positional_embeddings
872
+ else None
873
+ )
874
+
875
+ # 8. Denoising loop
876
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
877
+
878
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
879
+ # for DPM-solver++
880
+ old_pred_original_sample = None
881
+ for i, t in enumerate(timesteps):
882
+ if self.interrupt:
883
+ continue
884
+
885
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
886
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
887
+
888
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
889
+ timestep = t.expand(latent_model_input.shape[0])
890
+
891
+ # predict noise model_output
892
+ noise_pred = self.transformer(
893
+ hidden_states=latent_model_input,
894
+ encoder_hidden_states=prompt_embeds,
895
+ timestep=timestep,
896
+ image_rotary_emb=image_rotary_emb,
897
+ return_dict=False,
898
+ control_latents=control_latents,
899
+ )[0]
900
+ noise_pred = noise_pred.float()
901
+
902
+ # perform guidance
903
+ if use_dynamic_cfg:
904
+ self._guidance_scale = 1 + guidance_scale * (
905
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
906
+ )
907
+ if do_classifier_free_guidance:
908
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
909
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
910
+
911
+ # compute the previous noisy sample x_t -> x_t-1
912
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
913
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
914
+ else:
915
+ latents, old_pred_original_sample = self.scheduler.step(
916
+ noise_pred,
917
+ old_pred_original_sample,
918
+ t,
919
+ timesteps[i - 1] if i > 0 else None,
920
+ latents,
921
+ **extra_step_kwargs,
922
+ return_dict=False,
923
+ )
924
+ latents = latents.to(prompt_embeds.dtype)
925
+
926
+ # call the callback, if provided
927
+ if callback_on_step_end is not None:
928
+ callback_kwargs = {}
929
+ for k in callback_on_step_end_tensor_inputs:
930
+ callback_kwargs[k] = locals()[k]
931
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
932
+
933
+ latents = callback_outputs.pop("latents", latents)
934
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
935
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
936
+
937
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
938
+ progress_bar.update()
939
+ if comfyui_progressbar:
940
+ pbar.update(1)
941
+
942
+ if output_type == "numpy":
943
+ video = self.decode_latents(latents)
944
+ elif not output_type == "latent":
945
+ video = self.decode_latents(latents)
946
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
947
+ else:
948
+ video = latents
949
+
950
+ # Offload all models
951
+ self.maybe_free_model_hooks()
952
+
953
+ if not return_dict:
954
+ video = torch.from_numpy(video)
955
+
956
+ return CogVideoXFunPipelineOutput(videos=video)
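The `__call__` above ends by decoding the latents and wrapping the frames in `CogVideoXFunPipelineOutput`. As a minimal usage sketch (not part of the commit): the class name `CogVideoXFunControlPipeline`, the local checkpoint path, and the pixel range of `control_video` are assumptions for illustration; only the keyword arguments mirror the signature defined in this file.

```python
# Hedged sketch: assumes the control pipeline defined in this file is exported as
# CogVideoXFunControlPipeline and that a diffusers-format checkpoint exists locally.
import torch
from videox_fun.pipeline import CogVideoXFunControlPipeline  # assumed export

pipe = CogVideoXFunControlPipeline.from_pretrained(
    "./models/CogVideoX-Fun-Control",          # hypothetical local path
    torch_dtype=torch.bfloat16,
).to("cuda")

# (B, C, F, H, W) control clip in [0, 1]; it is preprocessed and VAE-encoded
# into control latents inside __call__.
control_video = torch.rand(1, 3, 49, 480, 720)

output = pipe(
    prompt="a robot dancing in the rain",
    control_video=control_video,
    num_frames=49,
    height=480,
    width=720,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(42),
    output_type="numpy",
)
frames = output.videos  # numpy frames when output_type="numpy"
```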
videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py ADDED
@@ -0,0 +1,1136 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from einops import rearrange
33
+
34
+ from ..models import (AutoencoderKLCogVideoX,
35
+ CogVideoXTransformer3DModel, T5EncoderModel,
36
+ T5Tokenizer)
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ pass
45
+ ```
46
+ """
47
+
48
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
49
+ def get_3d_rotary_pos_embed(
50
+ embed_dim,
51
+ crops_coords,
52
+ grid_size,
53
+ temporal_size,
54
+ theta: int = 10000,
55
+ use_real: bool = True,
56
+ grid_type: str = "linspace",
57
+ max_size: Optional[Tuple[int, int]] = None,
58
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
59
+ """
60
+ RoPE for video tokens with 3D structure.
61
+
62
+ Args:
63
+ embed_dim: (`int`):
64
+ The embedding dimension size, corresponding to hidden_size_head.
65
+ crops_coords (`Tuple[int]`):
66
+ The top-left and bottom-right coordinates of the crop.
67
+ grid_size (`Tuple[int]`):
68
+ The grid size of the spatial positional embedding (height, width).
69
+ temporal_size (`int`):
70
+ The size of the temporal dimension.
71
+ theta (`float`):
72
+ Scaling factor for frequency computation.
73
+ grid_type (`str`):
74
+ Whether to use "linspace" or "slice" to compute grids.
75
+
76
+ Returns:
77
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
78
+ """
79
+ if use_real is not True:
80
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
81
+
82
+ if grid_type == "linspace":
83
+ start, stop = crops_coords
84
+ grid_size_h, grid_size_w = grid_size
85
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
86
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
87
+ grid_t = np.arange(temporal_size, dtype=np.float32)
88
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
89
+ elif grid_type == "slice":
90
+ max_h, max_w = max_size
91
+ grid_size_h, grid_size_w = grid_size
92
+ grid_h = np.arange(max_h, dtype=np.float32)
93
+ grid_w = np.arange(max_w, dtype=np.float32)
94
+ grid_t = np.arange(temporal_size, dtype=np.float32)
95
+ else:
96
+ raise ValueError("Invalid value passed for `grid_type`.")
97
+
98
+ # Compute dimensions for each axis
99
+ dim_t = embed_dim // 4
100
+ dim_h = embed_dim // 8 * 3
101
+ dim_w = embed_dim // 8 * 3
102
+
103
+ # Temporal frequencies
104
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
105
+ # Spatial frequencies for height and width
106
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
107
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
108
+
109
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
110
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
111
+ freqs_t = freqs_t[:, None, None, :].expand(
112
+ -1, grid_size_h, grid_size_w, -1
113
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
114
+ freqs_h = freqs_h[None, :, None, :].expand(
115
+ temporal_size, -1, grid_size_w, -1
116
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
117
+ freqs_w = freqs_w[None, None, :, :].expand(
118
+ temporal_size, grid_size_h, -1, -1
119
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
120
+
121
+ freqs = torch.cat(
122
+ [freqs_t, freqs_h, freqs_w], dim=-1
123
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
124
+ freqs = freqs.view(
125
+ temporal_size * grid_size_h * grid_size_w, -1
126
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
127
+ return freqs
128
+
129
+ t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
130
+ h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
131
+ w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
132
+
133
+ if grid_type == "slice":
134
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
135
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
136
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
137
+
138
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
139
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
140
+ return cos, sin
141
+
142
+
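Before the next helper, a quick shape check may help: with `use_real=True` the function returns a `(cos, sin)` pair whose first dimension flattens time, height and width, and whose last dimension equals `embed_dim` whenever `embed_dim` is divisible by 8. The numbers below are illustrative only, not part of the committed file.

```python
# Illustrative shape check for get_3d_rotary_pos_embed defined above.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,                     # e.g. the transformer's attention_head_dim
    crops_coords=((0, 0), (4, 6)),    # (top-left, bottom-right) of the crop
    grid_size=(4, 6),                 # latent patch grid (height, width)
    temporal_size=3,
    grid_type="linspace",
)
assert cos.shape == sin.shape == (3 * 4 * 6, 64)
```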
143
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
144
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
145
+ tw = tgt_width
146
+ th = tgt_height
147
+ h, w = src
148
+ r = h / w
149
+ if r > (th / tw):
150
+ resize_height = th
151
+ resize_width = int(round(th / h * w))
152
+ else:
153
+ resize_width = tw
154
+ resize_height = int(round(tw / w * h))
155
+
156
+ crop_top = int(round((th - resize_height) / 2.0))
157
+ crop_left = int(round((tw - resize_width) / 2.0))
158
+
159
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
160
+
161
+
162
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
163
+ def retrieve_timesteps(
164
+ scheduler,
165
+ num_inference_steps: Optional[int] = None,
166
+ device: Optional[Union[str, torch.device]] = None,
167
+ timesteps: Optional[List[int]] = None,
168
+ sigmas: Optional[List[float]] = None,
169
+ **kwargs,
170
+ ):
171
+ """
172
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
173
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
174
+
175
+ Args:
176
+ scheduler (`SchedulerMixin`):
177
+ The scheduler to get timesteps from.
178
+ num_inference_steps (`int`):
179
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
180
+ must be `None`.
181
+ device (`str` or `torch.device`, *optional*):
182
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
183
+ timesteps (`List[int]`, *optional*):
184
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
185
+ `num_inference_steps` and `sigmas` must be `None`.
186
+ sigmas (`List[float]`, *optional*):
187
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
188
+ `num_inference_steps` and `timesteps` must be `None`.
189
+
190
+ Returns:
191
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
192
+ second element is the number of inference steps.
193
+ """
194
+ if timesteps is not None and sigmas is not None:
195
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
196
+ if timesteps is not None:
197
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
198
+ if not accepts_timesteps:
199
+ raise ValueError(
200
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
201
+ f" timestep schedules. Please check whether you are using the correct scheduler."
202
+ )
203
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
204
+ timesteps = scheduler.timesteps
205
+ num_inference_steps = len(timesteps)
206
+ elif sigmas is not None:
207
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
208
+ if not accept_sigmas:
209
+ raise ValueError(
210
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
211
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
212
+ )
213
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
214
+ timesteps = scheduler.timesteps
215
+ num_inference_steps = len(timesteps)
216
+ else:
217
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
218
+ timesteps = scheduler.timesteps
219
+ return timesteps, num_inference_steps
220
+
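A small, hedged example of how `retrieve_timesteps` behaves with a stock scheduler; the scheduler choice and step count here are arbitrary, the pipelines below pass in their own configured scheduler instead.

```python
# Sketch: default path of retrieve_timesteps with a CogVideoX DDIM scheduler.
from diffusers.schedulers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler()
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
assert num_steps == len(timesteps) == 30  # scheduler-built spacing, moved to the device

# Passing `timesteps=[...]` or `sigmas=[...]` instead overrides the spacing,
# but only one of the two may be given and the scheduler must support it.
```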
221
+
222
+ def resize_mask(mask, latent, process_first_frame_only=True):
223
+ latent_size = latent.size()
224
+ batch_size, channels, num_frames, height, width = mask.shape
225
+
226
+ if process_first_frame_only:
227
+ target_size = list(latent_size[2:])
228
+ target_size[0] = 1
229
+ first_frame_resized = F.interpolate(
230
+ mask[:, :, 0:1, :, :],
231
+ size=target_size,
232
+ mode='trilinear',
233
+ align_corners=False
234
+ )
235
+
236
+ target_size = list(latent_size[2:])
237
+ target_size[0] = target_size[0] - 1
238
+ if target_size[0] != 0:
239
+ remaining_frames_resized = F.interpolate(
240
+ mask[:, :, 1:, :, :],
241
+ size=target_size,
242
+ mode='trilinear',
243
+ align_corners=False
244
+ )
245
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
246
+ else:
247
+ resized_mask = first_frame_resized
248
+ else:
249
+ target_size = list(latent_size[2:])
250
+ resized_mask = F.interpolate(
251
+ mask,
252
+ size=target_size,
253
+ mode='trilinear',
254
+ align_corners=False
255
+ )
256
+ return resized_mask
257
+
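To make the shapes of `resize_mask` concrete, a toy example (sizes chosen arbitrarily): a pixel-space mask is interpolated down to the latent grid, with the first frame handled on its own when `process_first_frame_only=True`.

```python
# Illustrative shape check for resize_mask defined above.
import torch

mask = torch.ones(1, 1, 9, 64, 64)       # (B, 1, F, H, W) pixel-space mask
latent = torch.zeros(1, 16, 3, 8, 8)     # (B, C, f, h, w) latent tensor
resized = resize_mask(mask, latent, process_first_frame_only=True)
assert resized.shape == (1, 1, 3, 8, 8)  # first frame + remaining frames, concatenated in time
```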
258
+
259
+ def add_noise_to_reference_video(image, ratio=None):
260
+ if ratio is None:
261
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
262
+ sigma = torch.exp(sigma).to(image.dtype)
263
+ else:
264
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
265
+
266
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
267
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
268
+ image = image + image_noise
269
+ return image
270
+
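A short sketch of the helper above: with an explicit `ratio` every sample gets the same, deterministic noise scale, while pixels equal to -1 (the padding value used by the inpaint masking further down) receive no noise.

```python
# Illustrative call of add_noise_to_reference_video (shapes arbitrary).
import torch

reference = torch.zeros(2, 3, 9, 64, 64)          # (B, C, F, H, W), values in [-1, 1]
noisy = add_noise_to_reference_video(reference, ratio=0.05)
assert noisy.shape == reference.shape
```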
271
+
272
+ @dataclass
273
+ class CogVideoXFunPipelineOutput(BaseOutput):
274
+ r"""
275
+ Output class for CogVideo pipelines.
276
+
277
+ Args:
278
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
279
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
280
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
281
+ `(batch_size, num_frames, channels, height, width)`.
282
+ """
283
+
284
+ videos: torch.Tensor
285
+
286
+
287
+ class CogVideoXFunInpaintPipeline(DiffusionPipeline):
288
+ r"""
289
+ Pipeline for text-to-video generation using CogVideoX.
290
+
291
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
292
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
293
+
294
+ Args:
295
+ vae ([`AutoencoderKL`]):
296
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
297
+ text_encoder ([`T5EncoderModel`]):
298
+ Frozen text-encoder. CogVideoX_Fun uses
299
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
300
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
301
+ tokenizer (`T5Tokenizer`):
302
+ Tokenizer of class
303
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
304
+ transformer ([`CogVideoXTransformer3DModel`]):
305
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
306
+ scheduler ([`SchedulerMixin`]):
307
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
308
+ """
309
+
310
+ _optional_components = []
311
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
312
+
313
+ _callback_tensor_inputs = [
314
+ "latents",
315
+ "prompt_embeds",
316
+ "negative_prompt_embeds",
317
+ ]
318
+
319
+ def __init__(
320
+ self,
321
+ tokenizer: T5Tokenizer,
322
+ text_encoder: T5EncoderModel,
323
+ vae: AutoencoderKLCogVideoX,
324
+ transformer: CogVideoXTransformer3DModel,
325
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
326
+ ):
327
+ super().__init__()
328
+
329
+ self.register_modules(
330
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
331
+ )
332
+ self.vae_scale_factor_spatial = (
333
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
334
+ )
335
+ self.vae_scale_factor_temporal = (
336
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
337
+ )
338
+
339
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
340
+
341
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
342
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
343
+ self.mask_processor = VaeImageProcessor(
344
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
345
+ )
346
+
347
+ def _get_t5_prompt_embeds(
348
+ self,
349
+ prompt: Union[str, List[str]] = None,
350
+ num_videos_per_prompt: int = 1,
351
+ max_sequence_length: int = 226,
352
+ device: Optional[torch.device] = None,
353
+ dtype: Optional[torch.dtype] = None,
354
+ ):
355
+ device = device or self._execution_device
356
+ dtype = dtype or self.text_encoder.dtype
357
+
358
+ prompt = [prompt] if isinstance(prompt, str) else prompt
359
+ batch_size = len(prompt)
360
+
361
+ text_inputs = self.tokenizer(
362
+ prompt,
363
+ padding="max_length",
364
+ max_length=max_sequence_length,
365
+ truncation=True,
366
+ add_special_tokens=True,
367
+ return_tensors="pt",
368
+ )
369
+ text_input_ids = text_inputs.input_ids
370
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
371
+
372
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
373
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
374
+ logger.warning(
375
+ "The following part of your input was truncated because `max_sequence_length` is set to "
376
+ f" {max_sequence_length} tokens: {removed_text}"
377
+ )
378
+
379
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
380
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
381
+
382
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
383
+ _, seq_len, _ = prompt_embeds.shape
384
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
385
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
386
+
387
+ return prompt_embeds
388
+
389
+ def encode_prompt(
390
+ self,
391
+ prompt: Union[str, List[str]],
392
+ negative_prompt: Optional[Union[str, List[str]]] = None,
393
+ do_classifier_free_guidance: bool = True,
394
+ num_videos_per_prompt: int = 1,
395
+ prompt_embeds: Optional[torch.Tensor] = None,
396
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
397
+ max_sequence_length: int = 226,
398
+ device: Optional[torch.device] = None,
399
+ dtype: Optional[torch.dtype] = None,
400
+ ):
401
+ r"""
402
+ Encodes the prompt into text encoder hidden states.
403
+
404
+ Args:
405
+ prompt (`str` or `List[str]`, *optional*):
406
+ prompt to be encoded
407
+ negative_prompt (`str` or `List[str]`, *optional*):
408
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
409
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
410
+ less than `1`).
411
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
412
+ Whether to use classifier free guidance or not.
413
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
414
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
415
+ prompt_embeds (`torch.Tensor`, *optional*):
416
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
417
+ provided, text embeddings will be generated from `prompt` input argument.
418
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
419
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
420
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
421
+ argument.
422
+ device: (`torch.device`, *optional*):
423
+ torch device
424
+ dtype: (`torch.dtype`, *optional*):
425
+ torch dtype
426
+ """
427
+ device = device or self._execution_device
428
+
429
+ prompt = [prompt] if isinstance(prompt, str) else prompt
430
+ if prompt is not None:
431
+ batch_size = len(prompt)
432
+ else:
433
+ batch_size = prompt_embeds.shape[0]
434
+
435
+ if prompt_embeds is None:
436
+ prompt_embeds = self._get_t5_prompt_embeds(
437
+ prompt=prompt,
438
+ num_videos_per_prompt=num_videos_per_prompt,
439
+ max_sequence_length=max_sequence_length,
440
+ device=device,
441
+ dtype=dtype,
442
+ )
443
+
444
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
445
+ negative_prompt = negative_prompt or ""
446
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
447
+
448
+ if prompt is not None and type(prompt) is not type(negative_prompt):
449
+ raise TypeError(
450
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
451
+ f" {type(prompt)}."
452
+ )
453
+ elif batch_size != len(negative_prompt):
454
+ raise ValueError(
455
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
456
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
457
+ " the batch size of `prompt`."
458
+ )
459
+
460
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
461
+ prompt=negative_prompt,
462
+ num_videos_per_prompt=num_videos_per_prompt,
463
+ max_sequence_length=max_sequence_length,
464
+ device=device,
465
+ dtype=dtype,
466
+ )
467
+
468
+ return prompt_embeds, negative_prompt_embeds
469
+
470
+ def prepare_latents(
471
+ self,
472
+ batch_size,
473
+ num_channels_latents,
474
+ height,
475
+ width,
476
+ video_length,
477
+ dtype,
478
+ device,
479
+ generator,
480
+ latents=None,
481
+ video=None,
482
+ timestep=None,
483
+ is_strength_max=True,
484
+ return_noise=False,
485
+ return_video_latents=False,
486
+ ):
487
+ shape = (
488
+ batch_size,
489
+ (video_length - 1) // self.vae_scale_factor_temporal + 1,
490
+ num_channels_latents,
491
+ height // self.vae_scale_factor_spatial,
492
+ width // self.vae_scale_factor_spatial,
493
+ )
494
+ if isinstance(generator, list) and len(generator) != batch_size:
495
+ raise ValueError(
496
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
497
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
498
+ )
499
+
500
+ if return_video_latents or (latents is None and not is_strength_max):
501
+ video = video.to(device=device, dtype=self.vae.dtype)
502
+
503
+ bs = 1
504
+ new_video = []
505
+ for i in range(0, video.shape[0], bs):
506
+ video_bs = video[i : i + bs]
507
+ video_bs = self.vae.encode(video_bs)[0]
508
+ video_bs = video_bs.sample()
509
+ new_video.append(video_bs)
510
+ video = torch.cat(new_video, dim = 0)
511
+ video = video * self.vae.config.scaling_factor
512
+
513
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
514
+ video_latents = video_latents.to(device=device, dtype=dtype)
515
+ video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
516
+
517
+ if latents is None:
518
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
519
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
520
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
521
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
522
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
523
+ else:
524
+ noise = latents.to(device)
525
+ latents = noise * self.scheduler.init_noise_sigma
526
+
527
+ # scale the initial noise by the standard deviation required by the scheduler
528
+ outputs = (latents,)
529
+
530
+ if return_noise:
531
+ outputs += (noise,)
532
+
533
+ if return_video_latents:
534
+ outputs += (video_latents,)
535
+
536
+ return outputs
537
+
538
+ def prepare_mask_latents(
539
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
540
+ ):
541
+ # resize the mask to latents shape as we concatenate the mask to the latents
542
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
543
+ # and half precision
544
+
545
+ if mask is not None:
546
+ mask = mask.to(device=device, dtype=self.vae.dtype)
547
+ bs = 1
548
+ new_mask = []
549
+ for i in range(0, mask.shape[0], bs):
550
+ mask_bs = mask[i : i + bs]
551
+ mask_bs = self.vae.encode(mask_bs)[0]
552
+ mask_bs = mask_bs.mode()
553
+ new_mask.append(mask_bs)
554
+ mask = torch.cat(new_mask, dim = 0)
555
+ mask = mask * self.vae.config.scaling_factor
556
+
557
+ if masked_image is not None:
558
+ if self.transformer.config.add_noise_in_inpaint_model:
559
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
560
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
561
+ bs = 1
562
+ new_mask_pixel_values = []
563
+ for i in range(0, masked_image.shape[0], bs):
564
+ mask_pixel_values_bs = masked_image[i : i + bs]
565
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
566
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
567
+ new_mask_pixel_values.append(mask_pixel_values_bs)
568
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
569
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
570
+ else:
571
+ masked_image_latents = None
572
+
573
+ return mask, masked_image_latents
574
+
575
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
576
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
577
+ latents = 1 / self.vae.config.scaling_factor * latents
578
+
579
+ frames = self.vae.decode(latents).sample
580
+ frames = (frames / 2 + 0.5).clamp(0, 1)
581
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
582
+ frames = frames.cpu().float().numpy()
583
+ return frames
584
+
585
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
586
+ def prepare_extra_step_kwargs(self, generator, eta):
587
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
588
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
589
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
590
+ # and should be between [0, 1]
591
+
592
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
593
+ extra_step_kwargs = {}
594
+ if accepts_eta:
595
+ extra_step_kwargs["eta"] = eta
596
+
597
+ # check if the scheduler accepts generator
598
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
599
+ if accepts_generator:
600
+ extra_step_kwargs["generator"] = generator
601
+ return extra_step_kwargs
602
+
603
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
604
+ def check_inputs(
605
+ self,
606
+ prompt,
607
+ height,
608
+ width,
609
+ negative_prompt,
610
+ callback_on_step_end_tensor_inputs,
611
+ prompt_embeds=None,
612
+ negative_prompt_embeds=None,
613
+ ):
614
+ if height % 8 != 0 or width % 8 != 0:
615
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
616
+
617
+ if callback_on_step_end_tensor_inputs is not None and not all(
618
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
619
+ ):
620
+ raise ValueError(
621
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
622
+ )
623
+ if prompt is not None and prompt_embeds is not None:
624
+ raise ValueError(
625
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
626
+ " only forward one of the two."
627
+ )
628
+ elif prompt is None and prompt_embeds is None:
629
+ raise ValueError(
630
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
631
+ )
632
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
633
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
634
+
635
+ if prompt is not None and negative_prompt_embeds is not None:
636
+ raise ValueError(
637
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
638
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
639
+ )
640
+
641
+ if negative_prompt is not None and negative_prompt_embeds is not None:
642
+ raise ValueError(
643
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
644
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
645
+ )
646
+
647
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
648
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
649
+ raise ValueError(
650
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
651
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
652
+ f" {negative_prompt_embeds.shape}."
653
+ )
654
+
655
+ def fuse_qkv_projections(self) -> None:
656
+ r"""Enables fused QKV projections."""
657
+ self.fusing_transformer = True
658
+ self.transformer.fuse_qkv_projections()
659
+
660
+ def unfuse_qkv_projections(self) -> None:
661
+ r"""Disable QKV projection fusion if enabled."""
662
+ if not self.fusing_transformer:
663
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
664
+ else:
665
+ self.transformer.unfuse_qkv_projections()
666
+ self.fusing_transformer = False
667
+
668
+ def _prepare_rotary_positional_embeddings(
669
+ self,
670
+ height: int,
671
+ width: int,
672
+ num_frames: int,
673
+ device: torch.device,
674
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
675
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
676
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
677
+
678
+ p = self.transformer.config.patch_size
679
+ p_t = self.transformer.config.patch_size_t
680
+
681
+ base_size_width = self.transformer.config.sample_width // p
682
+ base_size_height = self.transformer.config.sample_height // p
683
+
684
+ if p_t is None:
685
+ # CogVideoX 1.0
686
+ grid_crops_coords = get_resize_crop_region_for_grid(
687
+ (grid_height, grid_width), base_size_width, base_size_height
688
+ )
689
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
690
+ embed_dim=self.transformer.config.attention_head_dim,
691
+ crops_coords=grid_crops_coords,
692
+ grid_size=(grid_height, grid_width),
693
+ temporal_size=num_frames,
694
+ )
695
+ else:
696
+ # CogVideoX 1.5
697
+ base_num_frames = (num_frames + p_t - 1) // p_t
698
+
699
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
700
+ embed_dim=self.transformer.config.attention_head_dim,
701
+ crops_coords=None,
702
+ grid_size=(grid_height, grid_width),
703
+ temporal_size=base_num_frames,
704
+ grid_type="slice",
705
+ max_size=(base_size_height, base_size_width),
706
+ )
707
+
708
+ freqs_cos = freqs_cos.to(device=device)
709
+ freqs_sin = freqs_sin.to(device=device)
710
+ return freqs_cos, freqs_sin
711
+
712
+ @property
713
+ def guidance_scale(self):
714
+ return self._guidance_scale
715
+
716
+ @property
717
+ def num_timesteps(self):
718
+ return self._num_timesteps
719
+
720
+ @property
721
+ def attention_kwargs(self):
722
+ return self._attention_kwargs
723
+
724
+ @property
725
+ def interrupt(self):
726
+ return self._interrupt
727
+
728
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
729
+ def get_timesteps(self, num_inference_steps, strength, device):
730
+ # get the original timestep using init_timestep
731
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
732
+
733
+ t_start = max(num_inference_steps - init_timestep, 0)
734
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
735
+
736
+ return timesteps, num_inference_steps - t_start
737
+
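As a worked example of the trimming above: with `num_inference_steps=50` and `strength=0.6`, `init_timestep = min(int(50 * 0.6), 50) = 30`, so `t_start = 20` and only the last 30 scheduler timesteps are kept; with `strength=1.0` nothing is trimmed and denoising starts from pure noise.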
738
+ @torch.no_grad()
739
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
740
+ def __call__(
741
+ self,
742
+ prompt: Optional[Union[str, List[str]]] = None,
743
+ negative_prompt: Optional[Union[str, List[str]]] = None,
744
+ height: int = 480,
745
+ width: int = 720,
746
+ video: Union[torch.FloatTensor] = None,
747
+ mask_video: Union[torch.FloatTensor] = None,
748
+ masked_video_latents: Union[torch.FloatTensor] = None,
749
+ num_frames: int = 49,
750
+ num_inference_steps: int = 50,
751
+ timesteps: Optional[List[int]] = None,
752
+ guidance_scale: float = 6,
753
+ use_dynamic_cfg: bool = False,
754
+ num_videos_per_prompt: int = 1,
755
+ eta: float = 0.0,
756
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
757
+ latents: Optional[torch.FloatTensor] = None,
758
+ prompt_embeds: Optional[torch.FloatTensor] = None,
759
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
760
+ output_type: str = "numpy",
761
+ return_dict: bool = False,
762
+ callback_on_step_end: Optional[
763
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
764
+ ] = None,
765
+ attention_kwargs: Optional[Dict[str, Any]] = None,
766
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
767
+ max_sequence_length: int = 226,
768
+ strength: float = 1,
769
+ noise_aug_strength: float = 0.0563,
770
+ comfyui_progressbar: bool = False,
771
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
772
+ """
773
+ Function invoked when calling the pipeline for generation.
774
+
775
+ Args:
776
+ prompt (`str` or `List[str]`, *optional*):
777
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
778
+ instead.
779
+ negative_prompt (`str` or `List[str]`, *optional*):
780
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
781
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
782
+ less than `1`).
783
+ height (`int`, *optional*, defaults to `480`):
784
+ The height in pixels of the generated video.
785
+ width (`int`, *optional*, defaults to `720`):
786
+ The width in pixels of the generated video.
787
+ num_frames (`int`, defaults to `49`):
788
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
789
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
790
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
791
+ needs to be satisfied is that of divisibility mentioned above.
792
+ num_inference_steps (`int`, *optional*, defaults to 50):
793
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
794
+ expense of slower inference.
795
+ timesteps (`List[int]`, *optional*):
796
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
797
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
798
+ passed will be used. Must be in descending order.
799
+ guidance_scale (`float`, *optional*, defaults to `6.0`):
800
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
801
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
802
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
803
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
804
+ usually at the expense of lower image quality.
805
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
806
+ The number of videos to generate per prompt.
807
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
808
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
809
+ to make generation deterministic.
810
+ latents (`torch.FloatTensor`, *optional*):
811
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
812
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
813
+ tensor will be generated by sampling using the supplied random `generator`.
814
+ prompt_embeds (`torch.FloatTensor`, *optional*):
815
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
816
+ provided, text embeddings will be generated from `prompt` input argument.
817
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
818
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
819
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
820
+ argument.
821
+ output_type (`str`, *optional*, defaults to `"pil"`):
822
+ The output format of the generated image. Choose between
823
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
824
+ return_dict (`bool`, *optional*, defaults to `True`):
825
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
826
+ of a plain tuple.
827
+ callback_on_step_end (`Callable`, *optional*):
828
+ A function that is called at the end of each denoising step during inference. The function is called
829
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
830
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
831
+ `callback_on_step_end_tensor_inputs`.
832
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
833
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
834
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
835
+ `._callback_tensor_inputs` attribute of your pipeline class.
836
+ max_sequence_length (`int`, defaults to `226`):
837
+ Maximum sequence length in encoded prompt. Must be consistent with
838
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
839
+
840
+ Examples:
841
+
842
+ Returns:
843
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
844
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
845
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
846
+ """
847
+
848
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
849
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
850
+
851
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
852
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
853
+ num_frames = num_frames or self.transformer.config.sample_frames
854
+
855
+ num_videos_per_prompt = 1
856
+
857
+ # 1. Check inputs. Raise error if not correct
858
+ self.check_inputs(
859
+ prompt,
860
+ height,
861
+ width,
862
+ negative_prompt,
863
+ callback_on_step_end_tensor_inputs,
864
+ prompt_embeds,
865
+ negative_prompt_embeds,
866
+ )
867
+ self._guidance_scale = guidance_scale
868
+ self._attention_kwargs = attention_kwargs
869
+ self._interrupt = False
870
+
871
+ # 2. Default call parameters
872
+ if prompt is not None and isinstance(prompt, str):
873
+ batch_size = 1
874
+ elif prompt is not None and isinstance(prompt, list):
875
+ batch_size = len(prompt)
876
+ else:
877
+ batch_size = prompt_embeds.shape[0]
878
+
879
+ device = self._execution_device
880
+
881
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
882
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
883
+ # corresponds to doing no classifier free guidance.
884
+ do_classifier_free_guidance = guidance_scale > 1.0
885
+
886
+ # 3. Encode input prompt
887
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
888
+ prompt,
889
+ negative_prompt,
890
+ do_classifier_free_guidance,
891
+ num_videos_per_prompt=num_videos_per_prompt,
892
+ prompt_embeds=prompt_embeds,
893
+ negative_prompt_embeds=negative_prompt_embeds,
894
+ max_sequence_length=max_sequence_length,
895
+ device=device,
896
+ )
897
+ if do_classifier_free_guidance:
898
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
899
+
900
+ # 4. set timesteps
901
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
902
+ timesteps, num_inference_steps = self.get_timesteps(
903
+ num_inference_steps=num_inference_steps, strength=strength, device=device
904
+ )
905
+ self._num_timesteps = len(timesteps)
906
+ if comfyui_progressbar:
907
+ from comfy.utils import ProgressBar
908
+ pbar = ProgressBar(num_inference_steps + 2)
909
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
910
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
911
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
912
+ is_strength_max = strength == 1.0
913
+
914
+ # 5. Prepare latents.
915
+ if video is not None:
916
+ video_length = video.shape[2]
917
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
918
+ init_video = init_video.to(dtype=torch.float32)
919
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
920
+ else:
921
+ init_video = None
922
+
923
+ # Magvae needs the number of frames to be 4n + 1.
924
+ local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
925
+ # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
926
+ patch_size_t = self.transformer.config.patch_size_t
927
+ additional_frames = 0
928
+ if patch_size_t is not None and local_latent_length % patch_size_t != 0:
929
+ additional_frames = local_latent_length % patch_size_t
930
+ num_frames -= additional_frames * self.vae_scale_factor_temporal
931
+ if num_frames <= 0:
932
+ num_frames = 1
933
+ if video is not None and video_length > num_frames:
934
+ logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
935
+ video_length = num_frames
936
+ video = video[:, :, :video_length]
937
+ init_video = init_video[:, :, :video_length]
938
+ mask_video = mask_video[:, :, :video_length]
939
+
940
+ num_channels_latents = self.vae.config.latent_channels
941
+ num_channels_transformer = self.transformer.config.in_channels
942
+ return_image_latents = num_channels_transformer == num_channels_latents
943
+
944
+ latents_outputs = self.prepare_latents(
945
+ batch_size * num_videos_per_prompt,
946
+ num_channels_latents,
947
+ height,
948
+ width,
949
+ video_length,
950
+ prompt_embeds.dtype,
951
+ device,
952
+ generator,
953
+ latents,
954
+ video=init_video,
955
+ timestep=latent_timestep,
956
+ is_strength_max=is_strength_max,
957
+ return_noise=True,
958
+ return_video_latents=return_image_latents,
959
+ )
960
+ if return_image_latents:
961
+ latents, noise, image_latents = latents_outputs
962
+ else:
963
+ latents, noise = latents_outputs
964
+ if comfyui_progressbar:
965
+ pbar.update(1)
966
+
967
+ if mask_video is not None:
968
+ if (mask_video == 255).all():
969
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
970
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
971
+
972
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
973
+ masked_video_latents_input = (
974
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
975
+ )
976
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
977
+ else:
978
+ # Prepare mask latent variables
979
+ video_length = video.shape[2]
980
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
981
+ mask_condition = mask_condition.to(dtype=torch.float32)
982
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
983
+
984
+ if num_channels_transformer != num_channels_latents:
985
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
986
+ if masked_video_latents is None:
987
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
988
+ else:
989
+ masked_video = masked_video_latents
990
+
991
+ _, masked_video_latents = self.prepare_mask_latents(
992
+ None,
993
+ masked_video,
994
+ batch_size,
995
+ height,
996
+ width,
997
+ prompt_embeds.dtype,
998
+ device,
999
+ generator,
1000
+ do_classifier_free_guidance,
1001
+ noise_aug_strength=noise_aug_strength,
1002
+ )
1003
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
1004
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
1005
+
1006
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1007
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1008
+
1009
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
1010
+ masked_video_latents_input = (
1011
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1012
+ )
1013
+
1014
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1015
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
1016
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
1017
+
1018
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
1019
+ else:
1020
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1021
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1022
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1023
+
1024
+ inpaint_latents = None
1025
+ else:
1026
+ if num_channels_transformer != num_channels_latents:
1027
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
1028
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1029
+
1030
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
1031
+ masked_video_latents_input = (
1032
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1033
+ )
1034
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1035
+ else:
1036
+ mask = torch.zeros_like(init_video[:, :1])
1037
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
1038
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1039
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1040
+
1041
+ inpaint_latents = None
1042
+ if comfyui_progressbar:
1043
+ pbar.update(1)
1044
+
1045
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1046
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1047
+
1048
+ # 7. Create rotary embeds if required
1049
+ image_rotary_emb = (
1050
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
1051
+ if self.transformer.config.use_rotary_positional_embeddings
1052
+ else None
1053
+ )
1054
+
1055
+ # 8. Denoising loop
1056
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1057
+
1058
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1059
+ # for DPM-solver++
1060
+ old_pred_original_sample = None
1061
+ for i, t in enumerate(timesteps):
1062
+ if self.interrupt:
1063
+ continue
1064
+
1065
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1066
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1067
+
1068
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1069
+ timestep = t.expand(latent_model_input.shape[0])
1070
+
1071
+ # predict noise model_output
1072
+ noise_pred = self.transformer(
1073
+ hidden_states=latent_model_input,
1074
+ encoder_hidden_states=prompt_embeds,
1075
+ timestep=timestep,
1076
+ image_rotary_emb=image_rotary_emb,
1077
+ return_dict=False,
1078
+ inpaint_latents=inpaint_latents,
1079
+ )[0]
1080
+ noise_pred = noise_pred.float()
1081
+
1082
+ # perform guidance
1083
+ if use_dynamic_cfg:
1084
+ self._guidance_scale = 1 + guidance_scale * (
1085
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
1086
+ )
1087
+ if do_classifier_free_guidance:
1088
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1089
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1090
+
1091
+ # compute the previous noisy sample x_t -> x_t-1
1092
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
1093
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1094
+ else:
1095
+ latents, old_pred_original_sample = self.scheduler.step(
1096
+ noise_pred,
1097
+ old_pred_original_sample,
1098
+ t,
1099
+ timesteps[i - 1] if i > 0 else None,
1100
+ latents,
1101
+ **extra_step_kwargs,
1102
+ return_dict=False,
1103
+ )
1104
+ latents = latents.to(prompt_embeds.dtype)
1105
+
1106
+ # call the callback, if provided
1107
+ if callback_on_step_end is not None:
1108
+ callback_kwargs = {}
1109
+ for k in callback_on_step_end_tensor_inputs:
1110
+ callback_kwargs[k] = locals()[k]
1111
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1112
+
1113
+ latents = callback_outputs.pop("latents", latents)
1114
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1115
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1116
+
1117
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1118
+ progress_bar.update()
1119
+ if comfyui_progressbar:
1120
+ pbar.update(1)
1121
+
1122
+ if output_type == "numpy":
1123
+ video = self.decode_latents(latents)
1124
+ elif not output_type == "latent":
1125
+ video = self.decode_latents(latents)
1126
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1127
+ else:
1128
+ video = latents
1129
+
1130
+ # Offload all models
1131
+ self.maybe_free_model_hooks()
1132
+
1133
+ if not return_dict:
1134
+ video = torch.from_numpy(video)
1135
+
1136
+ return CogVideoXFunPipelineOutput(videos=video)
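Side note on the `use_dynamic_cfg` branch in the denoising loop above: it rescales the guidance weight with a cosine ramp over the timestep. A minimal sketch of that schedule in isolation (the step values below are illustrative; in the pipeline `t` comes straight from the scheduler's timesteps):

```python
import math

def dynamic_cfg(t, guidance_scale=6.0, num_inference_steps=50):
    # Same cosine ramp as in the loop: guidance stays near 1 early on and
    # approaches 1 + guidance_scale as t falls toward 0.
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t) / num_inference_steps) ** 5.0)) / 2
    )

# Illustrative values only; real timesteps come from the scheduler.
for t in (50, 40, 25, 10, 0):
    print(t, round(dynamic_cfg(t), 3))
```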
videox_fun/pipeline/pipeline_fantasy_talking.py ADDED
@@ -0,0 +1,754 @@
1
+ import inspect
2
+ import math
3
+ import copy
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms.functional as TF
11
+ from diffusers import FlowMatchEulerDiscreteScheduler
12
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
13
+ from diffusers.image_processor import VaeImageProcessor
14
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+ from transformers import T5Tokenizer
23
+
24
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
25
+ Wan2_2Transformer3DModel_S2V, WanAudioEncoder,
26
+ WanT5EncoderModel)
27
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
28
+ get_sampling_sigmas)
29
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```python
37
+ pass
38
+ ```
39
+ """
40
+
41
+
42
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
43
+ def retrieve_timesteps(
44
+ scheduler,
45
+ num_inference_steps: Optional[int] = None,
46
+ device: Optional[Union[str, torch.device]] = None,
47
+ timesteps: Optional[List[int]] = None,
48
+ sigmas: Optional[List[float]] = None,
49
+ **kwargs,
50
+ ):
51
+ """
52
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
53
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
54
+
55
+ Args:
56
+ scheduler (`SchedulerMixin`):
57
+ The scheduler to get timesteps from.
58
+ num_inference_steps (`int`):
59
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
60
+ must be `None`.
61
+ device (`str` or `torch.device`, *optional*):
62
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
63
+ timesteps (`List[int]`, *optional*):
64
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
65
+ `num_inference_steps` and `sigmas` must be `None`.
66
+ sigmas (`List[float]`, *optional*):
67
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
68
+ `num_inference_steps` and `timesteps` must be `None`.
69
+
70
+ Returns:
71
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
72
+ second element is the number of inference steps.
73
+ """
74
+ if timesteps is not None and sigmas is not None:
75
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
76
+ if timesteps is not None:
77
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accepts_timesteps:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" timestep schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ elif sigmas is not None:
87
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
88
+ if not accept_sigmas:
89
+ raise ValueError(
90
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
91
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
92
+ )
93
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
94
+ timesteps = scheduler.timesteps
95
+ num_inference_steps = len(timesteps)
96
+ else:
97
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
98
+ timesteps = scheduler.timesteps
99
+ return timesteps, num_inference_steps
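For orientation, a hedged sketch of what the default branch of `retrieve_timesteps` boils down to, using a stock diffusers scheduler as a stand-in (the scheduler choice and step count are illustrative, not this pipeline's defaults):

```python
# Sketch only: probe the scheduler's signature, then set and read back timesteps.
import inspect
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Custom timesteps/sigmas are only forwarded when set_timesteps() accepts them.
accepts_timesteps = "timesteps" in inspect.signature(scheduler.set_timesteps).parameters
print(accepts_timesteps)

scheduler.set_timesteps(30, device="cpu")
timesteps, num_inference_steps = scheduler.timesteps, len(scheduler.timesteps)
print(num_inference_steps, timesteps[:3])  # 30 descending timestep values
```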

100
+
101
+
102
+ def resize_mask(mask, latent, process_first_frame_only=True):
103
+ latent_size = latent.size()
104
+ batch_size, channels, num_frames, height, width = mask.shape
105
+
106
+ if process_first_frame_only:
107
+ target_size = list(latent_size[2:])
108
+ target_size[0] = 1
109
+ first_frame_resized = F.interpolate(
110
+ mask[:, :, 0:1, :, :],
111
+ size=target_size,
112
+ mode='trilinear',
113
+ align_corners=False
114
+ )
115
+
116
+ target_size = list(latent_size[2:])
117
+ target_size[0] = target_size[0] - 1
118
+ if target_size[0] != 0:
119
+ remaining_frames_resized = F.interpolate(
120
+ mask[:, :, 1:, :, :],
121
+ size=target_size,
122
+ mode='trilinear',
123
+ align_corners=False
124
+ )
125
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
126
+ else:
127
+ resized_mask = first_frame_resized
128
+ else:
129
+ target_size = list(latent_size[2:])
130
+ resized_mask = F.interpolate(
131
+ mask,
132
+ size=target_size,
133
+ mode='trilinear',
134
+ align_corners=False
135
+ )
136
+ return resized_mask
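To make the shape contract of `resize_mask` concrete, a small sketch with dummy tensors (the shapes are illustrative assumptions, chosen only to show the first-frame handling):

```python
# Sketch: the first frame is resized on its own so it maps to the single
# "image" latent frame; the remaining frames map to the remaining latent frames.
import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 17, 64, 64)   # (b, c, f, h, w) pixel-space mask
latent = torch.rand(1, 16, 5, 8, 8)   # (b, c, f_lat, h_lat, w_lat) latent

first = F.interpolate(mask[:, :, 0:1], size=(1, 8, 8), mode="trilinear", align_corners=False)
rest = F.interpolate(mask[:, :, 1:], size=(5 - 1, 8, 8), mode="trilinear", align_corners=False)
resized = torch.cat([first, rest], dim=2)

assert resized.shape[2:] == latent.shape[2:]  # frame/height/width now match the latent
```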
137
+
138
+
139
+ @dataclass
140
+ class WanPipelineOutput(BaseOutput):
141
+ r"""
142
+ Output class for Wan pipelines.
143
+
144
+ Args:
145
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
146
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
147
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
148
+ `(batch_size, num_frames, channels, height, width)`.
149
+ """
150
+
151
+ videos: torch.Tensor
152
+
153
+
154
+ class FantasyTalkingPipeline(DiffusionPipeline):
155
+ r"""
156
+ Pipeline for audio-driven (FantasyTalking) video generation using Wan.
157
+
158
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
159
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
+ """
161
+
162
+ _optional_components = ["transformer_2", "audio_encoder"]
163
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
164
+
165
+ _callback_tensor_inputs = [
166
+ "latents",
167
+ "prompt_embeds",
168
+ "negative_prompt_embeds",
169
+ ]
170
+
171
+ def __init__(
172
+ self,
173
+ tokenizer: AutoTokenizer,
174
+ text_encoder: WanT5EncoderModel,
175
+ audio_encoder: WanAudioEncoder,
176
+ vae: AutoencoderKLWan,
177
+ transformer: Wan2_2Transformer3DModel_S2V,
178
+ clip_image_encoder: CLIPModel,
179
+ transformer_2: Wan2_2Transformer3DModel_S2V = None,
180
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.register_modules(
185
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
186
+ transformer_2=transformer_2, scheduler=scheduler, clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder
187
+ )
188
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
189
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
190
+ self.mask_processor = VaeImageProcessor(
191
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
192
+ )
193
+
194
+ def _get_t5_prompt_embeds(
195
+ self,
196
+ prompt: Union[str, List[str]] = None,
197
+ num_videos_per_prompt: int = 1,
198
+ max_sequence_length: int = 512,
199
+ device: Optional[torch.device] = None,
200
+ dtype: Optional[torch.dtype] = None,
201
+ ):
202
+ device = device or self._execution_device
203
+ dtype = dtype or self.text_encoder.dtype
204
+
205
+ prompt = [prompt] if isinstance(prompt, str) else prompt
206
+ batch_size = len(prompt)
207
+
208
+ text_inputs = self.tokenizer(
209
+ prompt,
210
+ padding="max_length",
211
+ max_length=max_sequence_length,
212
+ truncation=True,
213
+ add_special_tokens=True,
214
+ return_tensors="pt",
215
+ )
216
+ text_input_ids = text_inputs.input_ids
217
+ prompt_attention_mask = text_inputs.attention_mask
218
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
219
+
220
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
221
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
222
+ logger.warning(
223
+ "The following part of your input was truncated because `max_sequence_length` is set to "
224
+ f" {max_sequence_length} tokens: {removed_text}"
225
+ )
226
+
227
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
228
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
229
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
230
+
231
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
232
+ _, seq_len, _ = prompt_embeds.shape
233
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
234
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
235
+
236
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
237
+
238
+ def encode_prompt(
239
+ self,
240
+ prompt: Union[str, List[str]],
241
+ negative_prompt: Optional[Union[str, List[str]]] = None,
242
+ do_classifier_free_guidance: bool = True,
243
+ num_videos_per_prompt: int = 1,
244
+ prompt_embeds: Optional[torch.Tensor] = None,
245
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
246
+ max_sequence_length: int = 512,
247
+ device: Optional[torch.device] = None,
248
+ dtype: Optional[torch.dtype] = None,
249
+ ):
250
+ r"""
251
+ Encodes the prompt into text encoder hidden states.
252
+
253
+ Args:
254
+ prompt (`str` or `List[str]`, *optional*):
255
+ prompt to be encoded
256
+ negative_prompt (`str` or `List[str]`, *optional*):
257
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
258
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
259
+ less than `1`).
260
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
261
+ Whether to use classifier free guidance or not.
262
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
263
+ Number of videos that should be generated per prompt.
264
+ prompt_embeds (`torch.Tensor`, *optional*):
265
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
266
+ provided, text embeddings will be generated from `prompt` input argument.
267
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
268
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
269
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
270
+ argument.
271
+ device: (`torch.device`, *optional*):
272
+ torch device
273
+ dtype: (`torch.dtype`, *optional*):
274
+ torch dtype
275
+ """
276
+ device = device or self._execution_device
277
+
278
+ prompt = [prompt] if isinstance(prompt, str) else prompt
279
+ if prompt is not None:
280
+ batch_size = len(prompt)
281
+ else:
282
+ batch_size = prompt_embeds.shape[0]
283
+
284
+ if prompt_embeds is None:
285
+ prompt_embeds = self._get_t5_prompt_embeds(
286
+ prompt=prompt,
287
+ num_videos_per_prompt=num_videos_per_prompt,
288
+ max_sequence_length=max_sequence_length,
289
+ device=device,
290
+ dtype=dtype,
291
+ )
292
+
293
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
294
+ negative_prompt = negative_prompt or ""
295
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
296
+
297
+ if prompt is not None and type(prompt) is not type(negative_prompt):
298
+ raise TypeError(
299
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
300
+ f" {type(prompt)}."
301
+ )
302
+ elif batch_size != len(negative_prompt):
303
+ raise ValueError(
304
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
305
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
306
+ " the batch size of `prompt`."
307
+ )
308
+
309
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
310
+ prompt=negative_prompt,
311
+ num_videos_per_prompt=num_videos_per_prompt,
312
+ max_sequence_length=max_sequence_length,
313
+ device=device,
314
+ dtype=dtype,
315
+ )
316
+
317
+ return prompt_embeds, negative_prompt_embeds
318
+
319
+ def prepare_latents(
320
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
321
+ ):
322
+ if isinstance(generator, list) and len(generator) != batch_size:
323
+ raise ValueError(
324
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
325
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
326
+ )
327
+
328
+ shape = (
329
+ batch_size,
330
+ num_channels_latents,
331
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
332
+ height // self.vae.spatial_compression_ratio,
333
+ width // self.vae.spatial_compression_ratio,
334
+ )
335
+
336
+ if latents is None:
337
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
338
+ else:
339
+ latents = latents.to(device)
340
+
341
+ # scale the initial noise by the standard deviation required by the scheduler
342
+ if hasattr(self.scheduler, "init_noise_sigma"):
343
+ latents = latents * self.scheduler.init_noise_sigma
344
+ return latents
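The latent shape produced by `prepare_latents` follows directly from the VAE compression ratios; a small worked sketch (the 4x temporal / 8x spatial ratios are assumptions for illustration, the real values are read from `self.vae` at runtime):

```python
# Sketch: how a 49-frame, 480x720 request maps to a latent grid under
# assumed compression ratios (temporal 4x, spatial 8x).
num_frames, height, width = 49, 480, 720
temporal_compression_ratio, spatial_compression_ratio = 4, 8  # assumed values

latent_frames = (num_frames - 1) // temporal_compression_ratio + 1
latent_height = height // spatial_compression_ratio
latent_width = width // spatial_compression_ratio

print(latent_frames, latent_height, latent_width)  # 13 60 90
```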
345
+
346
+ def prepare_mask_latents(
347
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
348
+ ):
349
+ # resize the mask to latents shape as we concatenate the mask to the latents
350
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
351
+ # and half precision
352
+
353
+ if mask is not None:
354
+ mask = mask.to(device=device, dtype=self.vae.dtype)
355
+ bs = 1
356
+ new_mask = []
357
+ for i in range(0, mask.shape[0], bs):
358
+ mask_bs = mask[i : i + bs]
359
+ mask_bs = self.vae.encode(mask_bs)[0]
360
+ mask_bs = mask_bs.mode()
361
+ new_mask.append(mask_bs)
362
+ mask = torch.cat(new_mask, dim = 0)
363
+ # mask = mask * self.vae.config.scaling_factor
364
+
365
+ if masked_image is not None:
366
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
367
+ bs = 1
368
+ new_mask_pixel_values = []
369
+ for i in range(0, masked_image.shape[0], bs):
370
+ mask_pixel_values_bs = masked_image[i : i + bs]
371
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
372
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
373
+ new_mask_pixel_values.append(mask_pixel_values_bs)
374
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
375
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
376
+ else:
377
+ masked_image_latents = None
378
+
379
+ return mask, masked_image_latents
380
+
381
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
382
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
383
+ frames = (frames / 2 + 0.5).clamp(0, 1)
384
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
385
+ frames = frames.cpu().float().numpy()
386
+ return frames
387
+
388
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
389
+ def prepare_extra_step_kwargs(self, generator, eta):
390
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
391
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
392
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
393
+ # and should be between [0, 1]
394
+
395
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
396
+ extra_step_kwargs = {}
397
+ if accepts_eta:
398
+ extra_step_kwargs["eta"] = eta
399
+
400
+ # check if the scheduler accepts generator
401
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
402
+ if accepts_generator:
403
+ extra_step_kwargs["generator"] = generator
404
+ return extra_step_kwargs
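The pattern used in `prepare_extra_step_kwargs` — probing a callable's signature before forwarding optional kwargs — can be checked in isolation; a minimal sketch with a dummy step function standing in for a scheduler:

```python
# Sketch: only forward `eta`/`generator` if the target step() actually accepts them.
import inspect

def step_without_eta(model_output, timestep, sample, generator=None):
    return sample

params = set(inspect.signature(step_without_eta).parameters.keys())
extra_step_kwargs = {}
if "eta" in params:
    extra_step_kwargs["eta"] = 0.0
if "generator" in params:
    extra_step_kwargs["generator"] = None

print(extra_step_kwargs)  # {'generator': None} -- eta is skipped for this step()
```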
405
+
406
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
407
+ def check_inputs(
408
+ self,
409
+ prompt,
410
+ height,
411
+ width,
412
+ negative_prompt,
413
+ callback_on_step_end_tensor_inputs,
414
+ prompt_embeds=None,
415
+ negative_prompt_embeds=None,
416
+ ):
417
+ if height % 8 != 0 or width % 8 != 0:
418
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
419
+
420
+ if callback_on_step_end_tensor_inputs is not None and not all(
421
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
422
+ ):
423
+ raise ValueError(
424
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
425
+ )
426
+ if prompt is not None and prompt_embeds is not None:
427
+ raise ValueError(
428
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
429
+ " only forward one of the two."
430
+ )
431
+ elif prompt is None and prompt_embeds is None:
432
+ raise ValueError(
433
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
434
+ )
435
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
436
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
437
+
438
+ if prompt is not None and negative_prompt_embeds is not None:
439
+ raise ValueError(
440
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
441
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
442
+ )
443
+
444
+ if negative_prompt is not None and negative_prompt_embeds is not None:
445
+ raise ValueError(
446
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
447
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
448
+ )
449
+
450
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
451
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
452
+ raise ValueError(
453
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
454
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
455
+ f" {negative_prompt_embeds.shape}."
456
+ )
457
+
458
+ @property
459
+ def guidance_scale(self):
460
+ return self._guidance_scale
461
+
462
+ @property
463
+ def num_timesteps(self):
464
+ return self._num_timesteps
465
+
466
+ @property
467
+ def attention_kwargs(self):
468
+ return self._attention_kwargs
469
+
470
+ @property
471
+ def interrupt(self):
472
+ return self._interrupt
473
+
474
+ @torch.no_grad()
475
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
476
+ def __call__(
477
+ self,
478
+ prompt: Optional[Union[str, List[str]]] = None,
479
+ negative_prompt: Optional[Union[str, List[str]]] = None,
480
+ height: int = 480,
481
+ width: int = 720,
482
+ video: Union[torch.FloatTensor] = None,
483
+ mask_video: Union[torch.FloatTensor] = None,
484
+ audio_path = None,
485
+ num_frames: int = 49,
486
+ num_inference_steps: int = 50,
487
+ timesteps: Optional[List[int]] = None,
488
+ guidance_scale: float = 6,
489
+ num_videos_per_prompt: int = 1,
490
+ eta: float = 0.0,
491
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
492
+ latents: Optional[torch.FloatTensor] = None,
493
+ prompt_embeds: Optional[torch.FloatTensor] = None,
494
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
495
+ output_type: str = "numpy",
496
+ return_dict: bool = False,
497
+ callback_on_step_end: Optional[
498
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
499
+ ] = None,
500
+ attention_kwargs: Optional[Dict[str, Any]] = None,
501
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
502
+ clip_image: Image = None,
503
+ max_sequence_length: int = 512,
504
+ comfyui_progressbar: bool = False,
505
+ shift: int = 5,
506
+ fps: int = 16,
507
+ ) -> Union[WanPipelineOutput, Tuple]:
508
+ """
509
+ Function invoked when calling the pipeline for generation.
510
+ Args:
511
+
512
+ Examples:
513
+
514
+ Returns:
515
+
516
+ """
517
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
518
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
519
+ num_videos_per_prompt = 1
520
+
521
+ # 1. Check inputs. Raise error if not correct
522
+ self.check_inputs(
523
+ prompt,
524
+ height,
525
+ width,
526
+ negative_prompt,
527
+ callback_on_step_end_tensor_inputs,
528
+ prompt_embeds,
529
+ negative_prompt_embeds,
530
+ )
531
+ self._guidance_scale = guidance_scale
532
+ self._attention_kwargs = attention_kwargs
533
+ self._interrupt = False
534
+
535
+ # 2. Default call parameters
536
+ if prompt is not None and isinstance(prompt, str):
537
+ batch_size = 1
538
+ elif prompt is not None and isinstance(prompt, list):
539
+ batch_size = len(prompt)
540
+ else:
541
+ batch_size = prompt_embeds.shape[0]
542
+
543
+ device = self._execution_device
544
+ weight_dtype = self.text_encoder.dtype
545
+
546
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
547
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
548
+ # corresponds to doing no classifier free guidance.
549
+ do_classifier_free_guidance = guidance_scale > 1.0
550
+
551
+ # 3. Encode input prompt
552
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
553
+ prompt,
554
+ negative_prompt,
555
+ do_classifier_free_guidance,
556
+ num_videos_per_prompt=num_videos_per_prompt,
557
+ prompt_embeds=prompt_embeds,
558
+ negative_prompt_embeds=negative_prompt_embeds,
559
+ max_sequence_length=max_sequence_length,
560
+ device=device,
561
+ )
562
+ if do_classifier_free_guidance:
563
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
564
+ else:
565
+ in_prompt_embeds = prompt_embeds
566
+
567
+ # 4. Prepare timesteps
568
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
569
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
570
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
571
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
572
+ timesteps = self.scheduler.timesteps
573
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
574
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
575
+ timesteps, _ = retrieve_timesteps(
576
+ self.scheduler,
577
+ device=device,
578
+ sigmas=sampling_sigmas)
579
+ else:
580
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
581
+ self._num_timesteps = len(timesteps)
582
+ if comfyui_progressbar:
583
+ from comfy.utils import ProgressBar
584
+ pbar = ProgressBar(num_inference_steps + 2)
585
+
586
+ # 5. Prepare latents.
587
+ if video is not None:
588
+ video_length = video.shape[2]
589
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
590
+ init_video = init_video.to(dtype=torch.float32)
591
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
592
+ else:
593
+ init_video = None
594
+
595
+ latent_channels = self.vae.config.latent_channels
596
+ latents = self.prepare_latents(
597
+ batch_size * num_videos_per_prompt,
598
+ latent_channels,
599
+ num_frames,
600
+ height,
601
+ width,
602
+ weight_dtype,
603
+ device,
604
+ generator,
605
+ latents,
606
+ )
607
+ if comfyui_progressbar:
608
+ pbar.update(1)
609
+
610
+ # Prepare mask latent variables
611
+ if init_video is not None:
612
+ if (mask_video == 255).all():
613
+ mask_latents = torch.tile(
614
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
615
+ )
616
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
617
+ else:
618
+ bs, _, video_length, height, width = video.size()
619
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
620
+ mask_condition = mask_condition.to(dtype=torch.float32)
621
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
622
+
623
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
624
+ _, masked_video_latents = self.prepare_mask_latents(
625
+ None,
626
+ masked_video,
627
+ batch_size,
628
+ height,
629
+ width,
630
+ weight_dtype,
631
+ device,
632
+ generator,
633
+ do_classifier_free_guidance,
634
+ noise_aug_strength=None,
635
+ )
636
+
637
+ mask_condition = torch.concat(
638
+ [
639
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
640
+ mask_condition[:, :, 1:]
641
+ ], dim=2
642
+ )
643
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
644
+ mask_condition = mask_condition.transpose(1, 2)
645
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
646
+
647
+ # Prepare clip latent variables
648
+ if clip_image is not None:
649
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
650
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
651
+ else:
652
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
653
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
654
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
655
+ clip_context = torch.zeros_like(clip_context)
656
+
657
+ # Extract audio emb
658
+ audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps)
659
+
660
+ if comfyui_progressbar:
661
+ pbar.update(1)
662
+
663
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
664
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
665
+
666
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
667
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
668
+ # 7. Denoising loop
669
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
670
+ self.transformer.num_inference_steps = num_inference_steps
671
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
672
+ for i, t in enumerate(timesteps):
673
+ self.transformer.current_steps = i
674
+
675
+ if self.interrupt:
676
+ continue
677
+
678
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
679
+ if hasattr(self.scheduler, "scale_model_input"):
680
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
681
+
682
+ if init_video is not None:
683
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
684
+ masked_video_latents_input = (
685
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
686
+ )
687
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
688
+
689
+ clip_context_input = (
690
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
691
+ )
692
+
693
+ audio_wav2vec_fea_input = (
694
+ torch.cat([audio_wav2vec_fea] * 2) if do_classifier_free_guidance else audio_wav2vec_fea
695
+ )
696
+
697
+ audio_scale = torch.tensor(
698
+ [0.75, 1]
699
+ ).to(latent_model_input.device, latent_model_input.dtype)
700
+
701
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
702
+ timestep = t.expand(latent_model_input.shape[0])
703
+
704
+ # predict noise model_output
705
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
706
+ noise_pred = self.transformer(
707
+ x=latent_model_input,
708
+ context=in_prompt_embeds,
709
+ t=timestep,
710
+ seq_len=seq_len,
711
+ y=y,
712
+ audio_wav2vec_fea=audio_wav2vec_fea_input,
713
+ audio_scale=audio_scale,
714
+ clip_fea=clip_context_input,
715
+ )
716
+
717
+ # perform guidance
718
+ if do_classifier_free_guidance:
719
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
720
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
721
+
722
+ # compute the previous noisy sample x_t -> x_t-1
723
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
724
+
725
+ if callback_on_step_end is not None:
726
+ callback_kwargs = {}
727
+ for k in callback_on_step_end_tensor_inputs:
728
+ callback_kwargs[k] = locals()[k]
729
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
730
+
731
+ latents = callback_outputs.pop("latents", latents)
732
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
733
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
734
+
735
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
736
+ progress_bar.update()
737
+ if comfyui_progressbar:
738
+ pbar.update(1)
739
+
740
+ if output_type == "numpy":
741
+ video = self.decode_latents(latents)
742
+ elif not output_type == "latent":
743
+ video = self.decode_latents(latents)
744
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
745
+ else:
746
+ video = latents
747
+
748
+ # Offload all models
749
+ self.maybe_free_model_hooks()
750
+
751
+ if not return_dict:
752
+ video = torch.from_numpy(video)
753
+
754
+ return WanPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_flux.py ADDED
@@ -0,0 +1,978 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
2
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
26
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
27
+ replace_example_docstring)
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+
30
+ from ..models import (CLIPImageProcessor, CLIPTextModel,
31
+ CLIPTokenizer, CLIPVisionModelWithProjection,
32
+ FluxTransformer2DModel, T5EncoderModel, AutoencoderKL,
33
+ T5TokenizerFast)
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+ EXAMPLE_DOC_STRING = """
46
+ Examples:
47
+ ```py
48
+ >>> import torch
49
+ >>> from diffusers import FluxPipeline
50
+
51
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
52
+ >>> pipe.to("cuda")
53
+ >>> prompt = "A cat holding a sign that says hello world"
54
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
55
+ >>> # Refer to the pipeline documentation for more details.
56
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
57
+ >>> image.save("flux.png")
58
+ ```
59
+ """
60
+
61
+
62
+ def calculate_shift(
63
+ image_seq_len,
64
+ base_seq_len: int = 256,
65
+ max_seq_len: int = 4096,
66
+ base_shift: float = 0.5,
67
+ max_shift: float = 1.15,
68
+ ):
69
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
70
+ b = base_shift - m * base_seq_len
71
+ mu = image_seq_len * m + b
72
+ return mu
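With the default constants, `calculate_shift` is a linear interpolation between `base_shift` and `max_shift` in the image sequence length; a quick worked check (the 1024-token sequence length is an illustrative value):

```python
# Sketch: mu grows linearly with the packed image sequence length.
m = (1.15 - 0.5) / (4096 - 256)   # slope over the [base_seq_len, max_seq_len] range
b = 0.5 - m * 256                 # intercept so that mu(256) == 0.5
mu = 1024 * m + b                 # e.g. a 1024-token latent sequence
print(round(mu, 3))               # 0.63; mu(4096) would give 1.15
```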
73
+
74
+
75
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
76
+ def retrieve_timesteps(
77
+ scheduler,
78
+ num_inference_steps: Optional[int] = None,
79
+ device: Optional[Union[str, torch.device]] = None,
80
+ timesteps: Optional[List[int]] = None,
81
+ sigmas: Optional[List[float]] = None,
82
+ **kwargs,
83
+ ):
84
+ r"""
85
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
86
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
87
+
88
+ Args:
89
+ scheduler (`SchedulerMixin`):
90
+ The scheduler to get timesteps from.
91
+ num_inference_steps (`int`):
92
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
93
+ must be `None`.
94
+ device (`str` or `torch.device`, *optional*):
95
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
96
+ timesteps (`List[int]`, *optional*):
97
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
98
+ `num_inference_steps` and `sigmas` must be `None`.
99
+ sigmas (`List[float]`, *optional*):
100
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
101
+ `num_inference_steps` and `timesteps` must be `None`.
102
+
103
+ Returns:
104
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
105
+ second element is the number of inference steps.
106
+ """
107
+ if timesteps is not None and sigmas is not None:
108
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
111
+ if not accepts_timesteps:
112
+ raise ValueError(
113
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
114
+ f" timestep schedules. Please check whether you are using the correct scheduler."
115
+ )
116
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
117
+ timesteps = scheduler.timesteps
118
+ num_inference_steps = len(timesteps)
119
+ elif sigmas is not None:
120
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
121
+ if not accept_sigmas:
122
+ raise ValueError(
123
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
124
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
125
+ )
126
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
127
+ timesteps = scheduler.timesteps
128
+ num_inference_steps = len(timesteps)
129
+ else:
130
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ return timesteps, num_inference_steps
133
+
134
+
135
+ @dataclass
136
+ class FluxPipelineOutput(BaseOutput):
137
+ """
138
+ Output class for Flux image generation pipelines.
139
+
140
+ Args:
141
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
142
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
143
+ height, width, num_channels)`. PIL images or numpy array represent the denoised images of the diffusion
144
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
145
+ passed to the decoder.
146
+ """
147
+
148
+ images: Union[List[PIL.Image.Image], np.ndarray]
149
+
150
+
151
+ @dataclass
152
+ class FluxPriorReduxPipelineOutput(BaseOutput):
153
+ """
154
+ Output class for Flux Prior Redux pipelines.
155
+
156
+ Args:
157
+ prompt_embeds (`torch.Tensor`):
158
+ Text embeddings to be passed on to the downstream Flux pipeline.
159
+ pooled_prompt_embeds (`torch.Tensor`): Pooled text embeddings to be passed on to the downstream Flux pipeline.
160
+ """
161
+
162
+ prompt_embeds: torch.Tensor
163
+ pooled_prompt_embeds: torch.Tensor
164
+
165
+
166
+ class FluxPipeline(
167
+ DiffusionPipeline,
168
+ ):
169
+ r"""
170
+ The Flux pipeline for text-to-image generation.
171
+
172
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
173
+
174
+ Args:
175
+ transformer ([`FluxTransformer2DModel`]):
176
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
177
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
178
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
179
+ vae ([`AutoencoderKL`]):
180
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
181
+ text_encoder ([`CLIPTextModel`]):
182
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
183
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
184
+ text_encoder_2 ([`T5EncoderModel`]):
185
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
186
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
187
+ tokenizer (`CLIPTokenizer`):
188
+ Tokenizer of class
189
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
190
+ tokenizer_2 (`T5TokenizerFast`):
191
+ Second Tokenizer of class
192
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
193
+ """
194
+
195
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
196
+ _optional_components = ["image_encoder", "feature_extractor"]
197
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
198
+
199
+ def __init__(
200
+ self,
201
+ scheduler: FlowMatchEulerDiscreteScheduler,
202
+ vae: AutoencoderKL,
203
+ text_encoder: CLIPTextModel,
204
+ tokenizer: CLIPTokenizer,
205
+ text_encoder_2: T5EncoderModel,
206
+ tokenizer_2: T5TokenizerFast,
207
+ transformer: FluxTransformer2DModel,
208
+ image_encoder: CLIPVisionModelWithProjection = None,
209
+ feature_extractor: CLIPImageProcessor = None,
210
+ ):
211
+ super().__init__()
212
+
213
+ self.register_modules(
214
+ vae=vae,
215
+ text_encoder=text_encoder,
216
+ text_encoder_2=text_encoder_2,
217
+ tokenizer=tokenizer,
218
+ tokenizer_2=tokenizer_2,
219
+ transformer=transformer,
220
+ scheduler=scheduler,
221
+ image_encoder=image_encoder,
222
+ feature_extractor=feature_extractor,
223
+ )
224
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
225
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
226
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
227
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
228
+ self.tokenizer_max_length = (
229
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
230
+ )
231
+ self.default_sample_size = 128
232
+
233
+ def _get_t5_prompt_embeds(
234
+ self,
235
+ prompt: Union[str, List[str]] = None,
236
+ num_images_per_prompt: int = 1,
237
+ max_sequence_length: int = 512,
238
+ device: Optional[torch.device] = None,
239
+ dtype: Optional[torch.dtype] = None,
240
+ ):
241
+ device = device or self._execution_device
242
+ dtype = dtype or self.text_encoder.dtype
243
+
244
+ prompt = [prompt] if isinstance(prompt, str) else prompt
245
+ batch_size = len(prompt)
246
+
247
+ text_inputs = self.tokenizer_2(
248
+ prompt,
249
+ padding="max_length",
250
+ max_length=max_sequence_length,
251
+ truncation=True,
252
+ return_length=False,
253
+ return_overflowing_tokens=False,
254
+ return_tensors="pt",
255
+ )
256
+ text_input_ids = text_inputs.input_ids
257
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
258
+
259
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
260
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
261
+ logger.warning(
262
+ "The following part of your input was truncated because `max_sequence_length` is set to "
263
+ f" {max_sequence_length} tokens: {removed_text}"
264
+ )
265
+
266
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
267
+
268
+ dtype = self.text_encoder_2.dtype
269
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
270
+
271
+ _, seq_len, _ = prompt_embeds.shape
272
+
273
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
274
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
275
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
276
+
277
+ return prompt_embeds
278
+
279
+ def _get_clip_prompt_embeds(
280
+ self,
281
+ prompt: Union[str, List[str]],
282
+ num_images_per_prompt: int = 1,
283
+ device: Optional[torch.device] = None,
284
+ ):
285
+ device = device or self._execution_device
286
+
287
+ prompt = [prompt] if isinstance(prompt, str) else prompt
288
+ batch_size = len(prompt)
289
+
290
+ text_inputs = self.tokenizer(
291
+ prompt,
292
+ padding="max_length",
293
+ max_length=self.tokenizer_max_length,
294
+ truncation=True,
295
+ return_overflowing_tokens=False,
296
+ return_length=False,
297
+ return_tensors="pt",
298
+ )
299
+
300
+ text_input_ids = text_inputs.input_ids
301
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
302
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
303
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
304
+ logger.warning(
305
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
306
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
307
+ )
308
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
309
+
310
+ # Use pooled output of CLIPTextModel
311
+ prompt_embeds = prompt_embeds.pooler_output
312
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
313
+
314
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
315
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
316
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
317
+
318
+ return prompt_embeds
319
+
320
+ def encode_prompt(
321
+ self,
322
+ prompt: Union[str, List[str]],
323
+ prompt_2: Optional[Union[str, List[str]]] = None,
324
+ device: Optional[torch.device] = None,
325
+ num_images_per_prompt: int = 1,
326
+ prompt_embeds: Optional[torch.FloatTensor] = None,
327
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
328
+ max_sequence_length: int = 512,
329
+ lora_scale: Optional[float] = None,
330
+ ):
331
+ r"""
332
+
333
+ Args:
334
+ prompt (`str` or `List[str]`, *optional*):
335
+ prompt to be encoded
336
+ prompt_2 (`str` or `List[str]`, *optional*):
337
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
338
+ used in all text-encoders
339
+ device: (`torch.device`):
340
+ torch device
341
+ num_images_per_prompt (`int`):
342
+ number of images that should be generated per prompt
343
+ prompt_embeds (`torch.FloatTensor`, *optional*):
344
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
345
+ provided, text embeddings will be generated from `prompt` input argument.
346
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
347
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
348
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
349
+ lora_scale (`float`, *optional*):
350
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
351
+ """
352
+ device = device or self._execution_device
353
+
354
+ # set lora scale so that monkey patched LoRA
355
+ # function of text encoder can correctly access it
356
+ prompt = [prompt] if isinstance(prompt, str) else prompt
357
+
358
+ if prompt_embeds is None:
359
+ prompt_2 = prompt_2 or prompt
360
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
361
+
362
+ # We only use the pooled prompt output from the CLIPTextModel
363
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
364
+ prompt=prompt,
365
+ device=device,
366
+ num_images_per_prompt=num_images_per_prompt,
367
+ )
368
+ prompt_embeds = self._get_t5_prompt_embeds(
369
+ prompt=prompt_2,
370
+ num_images_per_prompt=num_images_per_prompt,
371
+ max_sequence_length=max_sequence_length,
372
+ device=device,
373
+ )
374
+
375
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
376
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
377
+
378
+ return prompt_embeds, pooled_prompt_embeds, text_ids
379
+
380
+ def encode_image(self, image, device, num_images_per_prompt):
381
+ dtype = next(self.image_encoder.parameters()).dtype
382
+
383
+ if not isinstance(image, torch.Tensor):
384
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
385
+
386
+ image = image.to(device=device, dtype=dtype)
387
+ image_embeds = self.image_encoder(image).image_embeds
388
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
389
+ return image_embeds
390
+
391
+ def prepare_ip_adapter_image_embeds(
392
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
393
+ ):
394
+ image_embeds = []
395
+ if ip_adapter_image_embeds is None:
396
+ if not isinstance(ip_adapter_image, list):
397
+ ip_adapter_image = [ip_adapter_image]
398
+
399
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
400
+ raise ValueError(
401
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
402
+ )
403
+
404
+ for single_ip_adapter_image in ip_adapter_image:
405
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
406
+ image_embeds.append(single_image_embeds[None, :])
407
+ else:
408
+ if not isinstance(ip_adapter_image_embeds, list):
409
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
410
+
411
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
412
+ raise ValueError(
413
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
414
+ )
415
+
416
+ for single_image_embeds in ip_adapter_image_embeds:
417
+ image_embeds.append(single_image_embeds)
418
+
419
+ ip_adapter_image_embeds = []
420
+ for single_image_embeds in image_embeds:
421
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
422
+ single_image_embeds = single_image_embeds.to(device=device)
423
+ ip_adapter_image_embeds.append(single_image_embeds)
424
+
425
+ return ip_adapter_image_embeds
426
+
427
+ def check_inputs(
428
+ self,
429
+ prompt,
430
+ prompt_2,
431
+ height,
432
+ width,
433
+ negative_prompt=None,
434
+ negative_prompt_2=None,
435
+ prompt_embeds=None,
436
+ negative_prompt_embeds=None,
437
+ pooled_prompt_embeds=None,
438
+ negative_pooled_prompt_embeds=None,
439
+ callback_on_step_end_tensor_inputs=None,
440
+ max_sequence_length=None,
441
+ ):
442
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
443
+ logger.warning(
444
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
445
+ )
446
+
447
+ if callback_on_step_end_tensor_inputs is not None and not all(
448
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
449
+ ):
450
+ raise ValueError(
451
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
452
+ )
453
+
454
+ if prompt is not None and prompt_embeds is not None:
455
+ raise ValueError(
456
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
457
+ " only forward one of the two."
458
+ )
459
+ elif prompt_2 is not None and prompt_embeds is not None:
460
+ raise ValueError(
461
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
462
+ " only forward one of the two."
463
+ )
464
+ elif prompt is None and prompt_embeds is None:
465
+ raise ValueError(
466
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
467
+ )
468
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
469
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
470
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
471
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
472
+
473
+ if negative_prompt is not None and negative_prompt_embeds is not None:
474
+ raise ValueError(
475
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
476
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
477
+ )
478
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
479
+ raise ValueError(
480
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
481
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
482
+ )
483
+
484
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
485
+ raise ValueError(
486
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
487
+ )
488
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
489
+ raise ValueError(
490
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
491
+ )
492
+
493
+ if max_sequence_length is not None and max_sequence_length > 512:
494
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
495
+
496
+ @staticmethod
497
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
498
+ latent_image_ids = torch.zeros(height, width, 3)
499
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
500
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
501
+
502
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
503
+
504
+ latent_image_ids = latent_image_ids.reshape(
505
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
506
+ )
507
+
508
+ return latent_image_ids.to(device=device, dtype=dtype)
509
+
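A quick sketch of what `_prepare_latent_image_ids` returns, assuming the surrounding class is named `FluxPipeline` as in upstream diffusers (the class definition sits outside this hunk); values shown for a toy 2x2 latent grid:

import torch

ids = FluxPipeline._prepare_latent_image_ids(1, 2, 2, "cpu", torch.float32)
# shape (height * width, 3); column 0 stays 0, columns 1 and 2 carry the row / column index:
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.],
#         [0., 1., 1.]])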
510
+ @staticmethod
511
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
512
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
513
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
514
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
515
+
516
+ return latents
517
+
518
+ @staticmethod
519
+ def _unpack_latents(latents, height, width, vae_scale_factor):
520
+ batch_size, num_patches, channels = latents.shape
521
+
522
+ # VAE applies 8x compression on images but we must also account for packing which requires
523
+ # latent height and width to be divisible by 2.
524
+ height = 2 * (int(height) // (vae_scale_factor * 2))
525
+ width = 2 * (int(width) // (vae_scale_factor * 2))
526
+
527
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
528
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
529
+
530
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
531
+
532
+ return latents
533
+
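A minimal round-trip sketch of the 2x2 packing used above, again assuming the class name `FluxPipeline` (not shown in this hunk). With `vae_scale_factor=8`, a 32x32 pixel target corresponds to a 4x4 latent grid:

import torch

B, C, H, W = 1, 16, 4, 4                                    # latent-space sizes (toy values)
latents = torch.randn(B, C, H, W)

packed = FluxPipeline._pack_latents(latents, B, C, H, W)    # (1, 4, 64) = (B, (H/2)*(W/2), C*4)
restored = FluxPipeline._unpack_latents(packed, height=32, width=32, vae_scale_factor=8)

assert torch.equal(restored, latents)                       # the packing is a lossless rearrangement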
534
+ def enable_vae_slicing(self):
535
+ r"""
536
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
537
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
538
+ """
539
+ self.vae.enable_slicing()
540
+
541
+ def disable_vae_slicing(self):
542
+ r"""
543
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
544
+ computing decoding in one step.
545
+ """
546
+ self.vae.disable_slicing()
547
+
548
+ def enable_vae_tiling(self):
549
+ r"""
550
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
551
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
552
+ processing larger images.
553
+ """
554
+ self.vae.enable_tiling()
555
+
556
+ def disable_vae_tiling(self):
557
+ r"""
558
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
559
+ computing decoding in one step.
560
+ """
561
+ self.vae.disable_tiling()
562
+
563
+ def prepare_latents(
564
+ self,
565
+ batch_size,
566
+ num_channels_latents,
567
+ height,
568
+ width,
569
+ dtype,
570
+ device,
571
+ generator,
572
+ latents=None,
573
+ ):
574
+ # VAE applies 8x compression on images but we must also account for packing which requires
575
+ # latent height and width to be divisible by 2.
576
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
577
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
578
+
579
+ shape = (batch_size, num_channels_latents, height, width)
580
+
581
+ if latents is not None:
582
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
583
+ return latents.to(device=device, dtype=dtype), latent_image_ids
584
+
585
+ if isinstance(generator, list) and len(generator) != batch_size:
586
+ raise ValueError(
587
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
588
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
589
+ )
590
+
591
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
592
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
593
+
594
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
595
+
596
+ return latents, latent_image_ids
597
+
598
+ @property
599
+ def guidance_scale(self):
600
+ return self._guidance_scale
601
+
602
+ @property
603
+ def joint_attention_kwargs(self):
604
+ return self._joint_attention_kwargs
605
+
606
+ @property
607
+ def num_timesteps(self):
608
+ return self._num_timesteps
609
+
610
+ @property
611
+ def current_timestep(self):
612
+ return self._current_timestep
613
+
614
+ @property
615
+ def interrupt(self):
616
+ return self._interrupt
617
+
618
+ @torch.no_grad()
619
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
620
+ def __call__(
621
+ self,
622
+ prompt: Union[str, List[str]] = None,
623
+ prompt_2: Optional[Union[str, List[str]]] = None,
624
+ negative_prompt: Union[str, List[str]] = None,
625
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
626
+ true_cfg_scale: float = 1.0,
627
+ height: Optional[int] = None,
628
+ width: Optional[int] = None,
629
+ num_inference_steps: int = 28,
630
+ sigmas: Optional[List[float]] = None,
631
+ guidance_scale: float = 3.5,
632
+ num_images_per_prompt: Optional[int] = 1,
633
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
634
+ latents: Optional[torch.FloatTensor] = None,
635
+ prompt_embeds: Optional[torch.FloatTensor] = None,
636
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
637
+ ip_adapter_image: Optional[PipelineImageInput] = None,
638
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
639
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
640
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
641
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
642
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
643
+ output_type: Optional[str] = "pil",
644
+ return_dict: bool = True,
645
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
646
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
647
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
648
+ max_sequence_length: int = 512,
649
+ ):
650
+ r"""
651
+ Function invoked when calling the pipeline for generation.
652
+
653
+ Args:
654
+ prompt (`str` or `List[str]`, *optional*):
655
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
656
+ instead.
657
+ prompt_2 (`str` or `List[str]`, *optional*):
658
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
659
+ will be used instead.
660
+ negative_prompt (`str` or `List[str]`, *optional*):
661
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
662
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
663
+ not greater than `1`).
664
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
665
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
666
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
667
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
668
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
669
+ `negative_prompt` is provided.
670
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
671
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
672
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
673
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
674
+ num_inference_steps (`int`, *optional*, defaults to 28):
675
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
676
+ expense of slower inference.
677
+ sigmas (`List[float]`, *optional*):
678
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
679
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
680
+ will be used.
681
+ guidance_scale (`float`, *optional*, defaults to 3.5):
682
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
683
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
684
+
685
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
686
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
687
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
688
+ The number of images to generate per prompt.
689
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
690
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
691
+ to make generation deterministic.
692
+ latents (`torch.FloatTensor`, *optional*):
693
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
694
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
695
+ tensor will be generated by sampling using the supplied random `generator`.
696
+ prompt_embeds (`torch.FloatTensor`, *optional*):
697
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
698
+ provided, text embeddings will be generated from `prompt` input argument.
699
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
700
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
701
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
702
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
703
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
704
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
705
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
706
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
707
+ negative_ip_adapter_image:
708
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
709
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
710
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
711
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
712
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
713
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
714
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
715
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
716
+ argument.
717
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
718
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
719
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
720
+ input argument.
721
+ output_type (`str`, *optional*, defaults to `"pil"`):
722
+ The output format of the generated image. Choose between
723
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
724
+ return_dict (`bool`, *optional*, defaults to `True`):
725
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
726
+ joint_attention_kwargs (`dict`, *optional*):
727
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
728
+ `self.processor` in
729
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
730
+ callback_on_step_end (`Callable`, *optional*):
731
+ A function that calls at the end of each denoising steps during the inference. The function is called
732
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
733
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
734
+ `callback_on_step_end_tensor_inputs`.
735
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
736
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
737
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
738
+ `._callback_tensor_inputs` attribute of your pipeline class.
739
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
740
+
741
+ Examples:
742
+
743
+ Returns:
744
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
745
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
746
+ images.
747
+ """
748
+
749
+ height = height or self.default_sample_size * self.vae_scale_factor
750
+ width = width or self.default_sample_size * self.vae_scale_factor
751
+
752
+ # 1. Check inputs. Raise error if not correct
753
+ self.check_inputs(
754
+ prompt,
755
+ prompt_2,
756
+ height,
757
+ width,
758
+ negative_prompt=negative_prompt,
759
+ negative_prompt_2=negative_prompt_2,
760
+ prompt_embeds=prompt_embeds,
761
+ negative_prompt_embeds=negative_prompt_embeds,
762
+ pooled_prompt_embeds=pooled_prompt_embeds,
763
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
764
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
765
+ max_sequence_length=max_sequence_length,
766
+ )
767
+
768
+ self._guidance_scale = guidance_scale
769
+ self._joint_attention_kwargs = joint_attention_kwargs
770
+ self._current_timestep = None
771
+ self._interrupt = False
772
+
773
+ # 2. Define call parameters
774
+ if prompt is not None and isinstance(prompt, str):
775
+ batch_size = 1
776
+ elif prompt is not None and isinstance(prompt, list):
777
+ batch_size = len(prompt)
778
+ else:
779
+ batch_size = prompt_embeds.shape[0]
780
+
781
+ device = self._execution_device
782
+
783
+ lora_scale = (
784
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
785
+ )
786
+ has_neg_prompt = negative_prompt is not None or (
787
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
788
+ )
789
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
790
+ (
791
+ prompt_embeds,
792
+ pooled_prompt_embeds,
793
+ text_ids,
794
+ ) = self.encode_prompt(
795
+ prompt=prompt,
796
+ prompt_2=prompt_2,
797
+ prompt_embeds=prompt_embeds,
798
+ pooled_prompt_embeds=pooled_prompt_embeds,
799
+ device=device,
800
+ num_images_per_prompt=num_images_per_prompt,
801
+ max_sequence_length=max_sequence_length,
802
+ lora_scale=lora_scale,
803
+ )
804
+ if do_true_cfg:
805
+ (
806
+ negative_prompt_embeds,
807
+ negative_pooled_prompt_embeds,
808
+ negative_text_ids,
809
+ ) = self.encode_prompt(
810
+ prompt=negative_prompt,
811
+ prompt_2=negative_prompt_2,
812
+ prompt_embeds=negative_prompt_embeds,
813
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
814
+ device=device,
815
+ num_images_per_prompt=num_images_per_prompt,
816
+ max_sequence_length=max_sequence_length,
817
+ lora_scale=lora_scale,
818
+ )
819
+
820
+ # 4. Prepare latent variables
821
+ num_channels_latents = self.transformer.config.in_channels // 4
822
+ latents, latent_image_ids = self.prepare_latents(
823
+ batch_size * num_images_per_prompt,
824
+ num_channels_latents,
825
+ height,
826
+ width,
827
+ prompt_embeds.dtype,
828
+ device,
829
+ generator,
830
+ latents,
831
+ )
832
+
833
+ # 5. Prepare timesteps
834
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
835
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
836
+ sigmas = None
837
+ image_seq_len = latents.shape[1]
838
+ mu = calculate_shift(
839
+ image_seq_len,
840
+ self.scheduler.config.get("base_image_seq_len", 256),
841
+ self.scheduler.config.get("max_image_seq_len", 4096),
842
+ self.scheduler.config.get("base_shift", 0.5),
843
+ self.scheduler.config.get("max_shift", 1.15),
844
+ )
845
+ timesteps, num_inference_steps = retrieve_timesteps(
846
+ self.scheduler,
847
+ num_inference_steps,
848
+ device,
849
+ sigmas=sigmas,
850
+ mu=mu,
851
+ )
852
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
853
+ self._num_timesteps = len(timesteps)
854
+
855
+ # handle guidance
856
+ if self.transformer.config.guidance_embeds:
857
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
858
+ guidance = guidance.expand(latents.shape[0])
859
+ else:
860
+ guidance = None
861
+
862
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
863
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
864
+ ):
865
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
866
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
867
+
868
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
869
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
870
+ ):
871
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
872
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
873
+
874
+ if self.joint_attention_kwargs is None:
875
+ self._joint_attention_kwargs = {}
876
+
877
+ image_embeds = None
878
+ negative_image_embeds = None
879
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
880
+ image_embeds = self.prepare_ip_adapter_image_embeds(
881
+ ip_adapter_image,
882
+ ip_adapter_image_embeds,
883
+ device,
884
+ batch_size * num_images_per_prompt,
885
+ )
886
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
887
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
888
+ negative_ip_adapter_image,
889
+ negative_ip_adapter_image_embeds,
890
+ device,
891
+ batch_size * num_images_per_prompt,
892
+ )
893
+
894
+ # 6. Denoising loop
895
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
896
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
897
+ self.scheduler.set_begin_index(0)
898
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
899
+ for i, t in enumerate(timesteps):
900
+ if self.interrupt:
901
+ continue
902
+
903
+ self._current_timestep = t
904
+ if image_embeds is not None:
905
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
906
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
907
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
908
+
909
+ with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
910
+ noise_pred = self.transformer(
911
+ hidden_states=latents,
912
+ timestep=timestep / 1000,
913
+ guidance=guidance,
914
+ pooled_projections=pooled_prompt_embeds,
915
+ encoder_hidden_states=prompt_embeds,
916
+ txt_ids=text_ids,
917
+ img_ids=latent_image_ids,
918
+ joint_attention_kwargs=self.joint_attention_kwargs,
919
+ return_dict=False,
920
+ )[0]
921
+
922
+ if do_true_cfg:
923
+ with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
924
+ neg_noise_pred = self.transformer(
925
+ hidden_states=latents,
926
+ timestep=timestep / 1000,
927
+ guidance=guidance,
928
+ pooled_projections=negative_pooled_prompt_embeds,
929
+ encoder_hidden_states=negative_prompt_embeds,
930
+ txt_ids=negative_text_ids,
931
+ img_ids=latent_image_ids,
932
+ joint_attention_kwargs=self.joint_attention_kwargs,
933
+ return_dict=False,
934
+ )[0]
935
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
936
+
937
+ # compute the previous noisy sample x_t -> x_t-1
938
+ latents_dtype = latents.dtype
939
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
940
+
941
+ if latents.dtype != latents_dtype:
942
+ if torch.backends.mps.is_available():
943
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
944
+ latents = latents.to(latents_dtype)
945
+
946
+ if callback_on_step_end is not None:
947
+ callback_kwargs = {}
948
+ for k in callback_on_step_end_tensor_inputs:
949
+ callback_kwargs[k] = locals()[k]
950
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
951
+
952
+ latents = callback_outputs.pop("latents", latents)
953
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
954
+
955
+ # call the callback, if provided
956
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
957
+ progress_bar.update()
958
+
959
+ if XLA_AVAILABLE:
960
+ xm.mark_step()
961
+
962
+ self._current_timestep = None
963
+
964
+ if output_type == "latent":
965
+ image = latents
966
+ else:
967
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
968
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
969
+ image = self.vae.decode(latents, return_dict=False)[0]
970
+ image = self.image_processor.postprocess(image, output_type=output_type)
971
+
972
+ # Offload all models
973
+ self.maybe_free_model_hooks()
974
+
975
+ if not return_dict:
976
+ return (image,)
977
+
978
+ return FluxPipelineOutput(images=image)
videox_fun/pipeline/pipeline_flux2.py ADDED
@@ -0,0 +1,900 @@
1
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py
2
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
18
+ replace_example_docstring)
19
+ from dataclasses import dataclass
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import PIL
24
+ import torch
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (is_torch_xla_available, logging,
28
+ replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor,
32
+ Flux2Transformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor)
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from diffusers import Flux2Pipeline
49
+
50
+ >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
51
+ >>> pipe.to("cuda")
52
+ >>> prompt = "A cat holding a sign that says hello world"
53
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
54
+ >>> # Refer to the pipeline documentation for more details.
55
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
56
+ >>> image.save("flux.png")
57
+ ```
58
+ """
59
+
60
+
61
+ def format_text_input(prompts: List[str], system_message: str = None):
62
+ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
63
+ # when truncation is enabled. The processor counts [IMG] tokens and fails
64
+ # if the count changes after truncation.
65
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
66
+
67
+ return [
68
+ [
69
+ {
70
+ "role": "system",
71
+ "content": [{"type": "text", "text": system_message}],
72
+ },
73
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
74
+ ]
75
+ for prompt in cleaned_txt
76
+ ]
77
+
78
+
79
+ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
80
+ a1, b1 = 8.73809524e-05, 1.89833333
81
+ a2, b2 = 0.00016927, 0.45666666
82
+
83
+ if image_seq_len > 4300:
84
+ mu = a2 * image_seq_len + b2
85
+ return float(mu)
86
+
87
+ m_200 = a2 * image_seq_len + b2
88
+ m_10 = a1 * image_seq_len + b1
89
+
90
+ a = (m_200 - m_10) / 190.0
91
+ b = m_200 - 200.0 * a
92
+ mu = a * num_steps + b
93
+
94
+ return float(mu)
95
+
96
+
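A rough numeric walk-through of `compute_empirical_mu` for a 1024x1024 generation (4096 image tokens) at 50 steps; the constants are the ones defined above and the intermediate values are approximate:

seq_len = 4096                                     # (1024 / 16) ** 2 packed latent tokens
m_10  = 8.73809524e-05 * seq_len + 1.89833333      # fitted shift at 10 steps,  ~2.26
m_200 = 0.00016927     * seq_len + 0.45666666      # fitted shift at 200 steps, ~1.15
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu_50 = a * 50 + b                                 # ~2.02, passed to the scheduler as `mu`
# for image_seq_len > 4300 the function skips the interpolation and returns a2 * seq_len + b2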
97
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
98
+ def retrieve_timesteps(
99
+ scheduler,
100
+ num_inference_steps: Optional[int] = None,
101
+ device: Optional[Union[str, torch.device]] = None,
102
+ timesteps: Optional[List[int]] = None,
103
+ sigmas: Optional[List[float]] = None,
104
+ **kwargs,
105
+ ):
106
+ r"""
107
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
108
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
109
+
110
+ Args:
111
+ scheduler (`SchedulerMixin`):
112
+ The scheduler to get timesteps from.
113
+ num_inference_steps (`int`):
114
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
115
+ must be `None`.
116
+ device (`str` or `torch.device`, *optional*):
117
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
118
+ timesteps (`List[int]`, *optional*):
119
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
120
+ `num_inference_steps` and `sigmas` must be `None`.
121
+ sigmas (`List[float]`, *optional*):
122
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
123
+ `num_inference_steps` and `timesteps` must be `None`.
124
+
125
+ Returns:
126
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
127
+ second element is the number of inference steps.
128
+ """
129
+ if timesteps is not None and sigmas is not None:
130
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
131
+ if timesteps is not None:
132
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accepts_timesteps:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" timestep schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ elif sigmas is not None:
142
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
143
+ if not accept_sigmas:
144
+ raise ValueError(
145
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
146
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
147
+ )
148
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ num_inference_steps = len(timesteps)
151
+ else:
152
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ return timesteps, num_inference_steps
155
+
156
+
157
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
158
+ def retrieve_latents(
159
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
160
+ ):
161
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
162
+ return encoder_output.latent_dist.sample(generator)
163
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
164
+ return encoder_output.latent_dist.mode()
165
+ elif hasattr(encoder_output, "latents"):
166
+ return encoder_output.latents
167
+ else:
168
+ raise AttributeError("Could not access latents of provided encoder_output")
169
+
170
+
171
+ @dataclass
172
+ class Flux2PipelineOutput(BaseOutput):
173
+ """
174
+ Output class for Flux2 image generation pipelines.
175
+
176
+ Args:
177
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
178
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
179
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
180
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
181
+ passed to the decoder.
182
+ """
183
+
184
+ images: Union[List[PIL.Image.Image], np.ndarray]
185
+
186
+
187
+ class Flux2Pipeline(DiffusionPipeline):
188
+ r"""
189
+ The Flux2 pipeline for text-to-image generation.
190
+
191
+ Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
192
+
193
+ Args:
194
+ transformer ([`Flux2Transformer2DModel`]):
195
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
196
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
197
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
198
+ vae ([`AutoencoderKLFlux2`]):
199
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
200
+ text_encoder ([`Mistral3ForConditionalGeneration`]):
201
+ [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
202
+ tokenizer (`AutoProcessor`):
203
+ Tokenizer of class
204
+ [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
205
+ """
206
+
207
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
208
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
209
+
210
+ def __init__(
211
+ self,
212
+ scheduler: FlowMatchEulerDiscreteScheduler,
213
+ vae: AutoencoderKLFlux2,
214
+ text_encoder: Mistral3ForConditionalGeneration,
215
+ tokenizer: AutoProcessor,
216
+ transformer: Flux2Transformer2DModel,
217
+ ):
218
+ super().__init__()
219
+
220
+ self.register_modules(
221
+ vae=vae,
222
+ text_encoder=text_encoder,
223
+ tokenizer=tokenizer,
224
+ scheduler=scheduler,
225
+ transformer=transformer,
226
+ )
227
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
228
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
229
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
230
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
231
+ self.tokenizer_max_length = 512
232
+ self.default_sample_size = 128
233
+
234
+ # fmt: off
235
+ self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
236
+ # fmt: on
237
+
238
+ @staticmethod
239
+ def _get_mistral_3_small_prompt_embeds(
240
+ text_encoder: Mistral3ForConditionalGeneration,
241
+ tokenizer: AutoProcessor,
242
+ prompt: Union[str, List[str]],
243
+ dtype: Optional[torch.dtype] = None,
244
+ device: Optional[torch.device] = None,
245
+ max_sequence_length: int = 512,
246
+ # fmt: off
247
+ system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
248
+ # fmt: on
249
+ hidden_states_layers: List[int] = (10, 20, 30),
250
+ ):
251
+ dtype = text_encoder.dtype if dtype is None else dtype
252
+ device = text_encoder.device if device is None else device
253
+
254
+ prompt = [prompt] if isinstance(prompt, str) else prompt
255
+
256
+ # Format input messages
257
+ messages_batch = format_text_input(prompts=prompt, system_message=system_message)
258
+
259
+ # Process all messages at once
260
+ inputs = tokenizer.apply_chat_template(
261
+ messages_batch,
262
+ add_generation_prompt=False,
263
+ tokenize=True,
264
+ return_dict=True,
265
+ return_tensors="pt",
266
+ padding="max_length",
267
+ truncation=True,
268
+ max_length=max_sequence_length,
269
+ )
270
+
271
+ # Move to device
272
+ input_ids = inputs["input_ids"].to(device)
273
+ attention_mask = inputs["attention_mask"].to(device)
274
+
275
+ # Forward pass through the model
276
+ output = text_encoder(
277
+ input_ids=input_ids,
278
+ attention_mask=attention_mask,
279
+ output_hidden_states=True,
280
+ use_cache=False,
281
+ )
282
+
283
+ # Only use outputs from intermediate layers and stack them
284
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
285
+ out = out.to(dtype=dtype, device=device)
286
+
287
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
288
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
289
+
290
+ return prompt_embeds
291
+
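A shape-only sketch of how `_get_mistral_3_small_prompt_embeds` fuses the intermediate layers (sizes are illustrative; `D` is the text encoder hidden size):

# output.hidden_states[k]    -> (B, L, D)     for each k in hidden_states_layers
# torch.stack([...], dim=1)  -> (B, 3, L, D)
# permute(0, 2, 1, 3)        -> (B, L, 3, D)
# reshape(B, L, 3 * D)       -> prompt_embeds handed to the transformer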
292
+ @staticmethod
293
+ def _prepare_text_ids(
294
+ x: torch.Tensor, # (B, L, D) or (L, D)
295
+ t_coord: Optional[torch.Tensor] = None,
296
+ ):
297
+ B, L, _ = x.shape
298
+ out_ids = []
299
+
300
+ for i in range(B):
301
+ t = torch.arange(1) if t_coord is None else t_coord[i]
302
+ h = torch.arange(1)
303
+ w = torch.arange(1)
304
+ l = torch.arange(L)
305
+
306
+ coords = torch.cartesian_prod(t, h, w, l)
307
+ out_ids.append(coords)
308
+
309
+ return torch.stack(out_ids)
310
+
311
+ def encode_prompt(
312
+ self,
313
+ prompt: Union[str, List[str]],
314
+ device: Optional[torch.device] = None,
315
+ num_images_per_prompt: int = 1,
316
+ prompt_embeds: Optional[torch.Tensor] = None,
317
+ max_sequence_length: int = 512,
318
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
319
+ ):
320
+ device = device or self._execution_device
321
+
322
+ if prompt is None:
323
+ prompt = ""
324
+
325
+ prompt = [prompt] if isinstance(prompt, str) else prompt
326
+
327
+ if prompt_embeds is None:
328
+ prompt_embeds = self._get_mistral_3_small_prompt_embeds(
329
+ text_encoder=self.text_encoder,
330
+ tokenizer=self.tokenizer,
331
+ prompt=prompt,
332
+ device=device,
333
+ max_sequence_length=max_sequence_length,
334
+ system_message=self.system_message,
335
+ hidden_states_layers=text_encoder_out_layers,
336
+ )
337
+
338
+ batch_size, seq_len, _ = prompt_embeds.shape
339
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
340
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
341
+
342
+ text_ids = self._prepare_text_ids(prompt_embeds)
343
+ text_ids = text_ids.to(device)
344
+ return prompt_embeds, text_ids
345
+
346
+ @staticmethod
347
+ def _prepare_latent_ids(
348
+ latents: torch.Tensor, # (B, C, H, W)
349
+ ):
350
+ r"""
351
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
352
+
353
+ Args:
354
+ latents (torch.Tensor):
355
+ Latent tensor of shape (B, C, H, W)
356
+
357
+ Returns:
358
+ torch.Tensor:
359
+ Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
360
+ H=[0..H-1], W=[0..W-1], L=0
361
+ """
362
+
363
+ batch_size, _, height, width = latents.shape
364
+
365
+ t = torch.arange(1) # [0] - time dimension
366
+ h = torch.arange(height)
367
+ w = torch.arange(width)
368
+ l = torch.arange(1) # [0] - layer dimension
369
+
370
+ # Create position IDs: (H*W, 4)
371
+ latent_ids = torch.cartesian_prod(t, h, w, l)
372
+
373
+ # Expand to batch: (B, H*W, 4)
374
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
375
+
376
+ return latent_ids
377
+
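A tiny example of the coordinate grid `_prepare_latent_ids` builds (toy shape, values written out by hand):

import torch

lat = torch.zeros(1, 4, 2, 2)                   # toy (B, C, H, W) latent
ids = Flux2Pipeline._prepare_latent_ids(lat)    # (1, H*W, 4), rows are (T, H, W, L)
# tensor([[[0, 0, 0, 0],
#          [0, 0, 1, 0],
#          [0, 1, 0, 0],
#          [0, 1, 1, 0]]])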
378
+ @staticmethod
379
+ def _prepare_image_ids(
380
+ image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
381
+ scale: int = 10,
382
+ ):
383
+ r"""
384
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
385
+
386
+ This function creates a unique coordinate for every pixel/patch across all input latents with different
387
+ dimensions.
388
+
389
+ Args:
390
+ image_latents (List[torch.Tensor]):
391
+ A list of image latent feature tensors, typically of shape (1, C, H, W).
392
+ scale (int, optional):
393
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
394
+ latent is: 'scale + scale * i'. Defaults to 10.
395
+
396
+ Returns:
397
+ torch.Tensor:
398
+ The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
399
+ input latents.
400
+
401
+ Coordinate Components (Dimension 4):
402
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
403
+ - H (Height): The row index within that latent image.
404
+ - W (Width): The column index within that latent image.
405
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
406
+ """
407
+
408
+ if not isinstance(image_latents, list):
409
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
410
+
411
+ # create time offset for each reference image
412
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
413
+ t_coords = [t.view(-1) for t in t_coords]
414
+
415
+ image_latent_ids = []
416
+ for x, t in zip(image_latents, t_coords):
417
+ x = x.squeeze(0)
418
+ _, height, width = x.shape
419
+
420
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
421
+ image_latent_ids.append(x_ids)
422
+
423
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
424
+ image_latent_ids = image_latent_ids.unsqueeze(0)
425
+
426
+ return image_latent_ids
427
+
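A matching sketch for `_prepare_image_ids`: each reference latent gets its own time offset (`scale + scale * i`), so tokens from different reference images carry distinct T coordinates (toy shapes, values written out by hand):

import torch

refs = [torch.zeros(1, 4, 1, 2), torch.zeros(1, 4, 1, 1)]   # two toy reference latents
ids = Flux2Pipeline._prepare_image_ids(refs, scale=10)      # (1, 2 + 1, 4)
# tensor([[[10, 0, 0, 0],      # first reference,  T = 10
#          [10, 0, 1, 0],
#          [20, 0, 0, 0]]])    # second reference, T = 20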
428
+ @staticmethod
429
+ def _patchify_latents(latents):
430
+ batch_size, num_channels_latents, height, width = latents.shape
431
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
432
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
433
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
434
+ return latents
435
+
436
+ @staticmethod
437
+ def _unpatchify_latents(latents):
438
+ batch_size, num_channels_latents, height, width = latents.shape
439
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
440
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
441
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
442
+ return latents
443
+
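A minimal round-trip check of `_patchify_latents` / `_unpatchify_latents` (toy sizes; H and W must be even):

import torch

x = torch.randn(1, 32, 8, 8)                            # (B, C, H, W)
patched = Flux2Pipeline._patchify_latents(x)            # (1, 128, 4, 4): each 2x2 spatial block folds into channels
restored = Flux2Pipeline._unpatchify_latents(patched)   # (1, 32, 8, 8)
assert torch.equal(restored, x)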
444
+ @staticmethod
445
+ def _pack_latents(latents):
446
+ """
447
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
448
+ """
449
+
450
+ batch_size, num_channels, height, width = latents.shape
451
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
452
+
453
+ return latents
454
+
455
+ @staticmethod
456
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
457
+ """
458
+ using position ids to scatter tokens into place
459
+ """
460
+ x_list = []
461
+ for data, pos in zip(x, x_ids):
462
+ _, ch = data.shape # noqa: F841
463
+ h_ids = pos[:, 1].to(torch.int64)
464
+ w_ids = pos[:, 2].to(torch.int64)
465
+
466
+ h = torch.max(h_ids) + 1
467
+ w = torch.max(w_ids) + 1
468
+
469
+ flat_ids = h_ids * w + w_ids
470
+
471
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
472
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
473
+
474
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
475
+
476
+ out = out.view(h, w, ch).permute(2, 0, 1)
477
+ x_list.append(out)
478
+
479
+ return torch.stack(x_list, dim=0)
480
+
481
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
482
+ if image.ndim != 4:
483
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
484
+
485
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
486
+ image_latents = self._patchify_latents(image_latents)
487
+
488
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
489
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
490
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
491
+
492
+ return image_latents
493
+
494
+ def prepare_latents(
495
+ self,
496
+ batch_size,
497
+ num_latents_channels,
498
+ height,
499
+ width,
500
+ dtype,
501
+ device,
502
+ generator: torch.Generator,
503
+ latents: Optional[torch.Tensor] = None,
504
+ ):
505
+ # VAE applies 8x compression on images but we must also account for packing which requires
506
+ # latent height and width to be divisible by 2.
507
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
508
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
509
+
510
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
511
+ if isinstance(generator, list) and len(generator) != batch_size:
512
+ raise ValueError(
513
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515
+ )
516
+ if latents is None:
517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518
+ else:
519
+ latents = latents.to(device=device, dtype=dtype)
520
+
521
+ latent_ids = self._prepare_latent_ids(latents)
522
+ latent_ids = latent_ids.to(device)
523
+
524
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
525
+ return latents, latent_ids
526
+
527
+ def prepare_image_latents(
528
+ self,
529
+ images: List[torch.Tensor],
530
+ batch_size,
531
+ generator: torch.Generator,
532
+ device,
533
+ dtype,
534
+ ):
535
+ image_latents = []
536
+ for image in images:
537
+ image = image.to(device=device, dtype=dtype)
538
+ image_latent = self._encode_vae_image(image=image, generator=generator)
539
+ image_latents.append(image_latent) # (1, 128, 32, 32)
540
+
541
+ image_latent_ids = self._prepare_image_ids(image_latents)
542
+
543
+ # Pack each latent and concatenate
544
+ packed_latents = []
545
+ for latent in image_latents:
546
+ # latent: (1, 128, 32, 32)
547
+ packed = self._pack_latents(latent) # (1, 1024, 128)
548
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
549
+ packed_latents.append(packed)
550
+
551
+ # Concatenate all reference tokens along sequence dimension
552
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
553
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
554
+
555
+ image_latents = image_latents.repeat(batch_size, 1, 1)
556
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
557
+ image_latent_ids = image_latent_ids.to(device)
558
+
559
+ return image_latents, image_latent_ids
560
+
561
+ def check_inputs(
562
+ self,
563
+ prompt,
564
+ height,
565
+ width,
566
+ prompt_embeds=None,
567
+ callback_on_step_end_tensor_inputs=None,
568
+ ):
569
+ if (
570
+ height is not None
571
+ and height % (self.vae_scale_factor * 2) != 0
572
+ or width is not None
573
+ and width % (self.vae_scale_factor * 2) != 0
574
+ ):
575
+ logger.warning(
576
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
577
+ )
578
+
579
+ if callback_on_step_end_tensor_inputs is not None and not all(
580
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
581
+ ):
582
+ raise ValueError(
583
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
584
+ )
585
+
586
+ if prompt is not None and prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
589
+ " only forward one of the two."
590
+ )
591
+ elif prompt is None and prompt_embeds is None:
592
+ raise ValueError(
593
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
594
+ )
595
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
596
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
597
+
598
+ @property
599
+ def guidance_scale(self):
600
+ return self._guidance_scale
601
+
602
+ @property
603
+ def joint_attention_kwargs(self):
604
+ return self._joint_attention_kwargs
605
+
606
+ @property
607
+ def num_timesteps(self):
608
+ return self._num_timesteps
609
+
610
+ @property
611
+ def current_timestep(self):
612
+ return self._current_timestep
613
+
614
+ @property
615
+ def interrupt(self):
616
+ return self._interrupt
617
+
618
+ @torch.no_grad()
619
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
620
+ def __call__(
621
+ self,
622
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
623
+ prompt: Union[str, List[str]] = None,
624
+ height: Optional[int] = None,
625
+ width: Optional[int] = None,
626
+ num_inference_steps: int = 50,
627
+ sigmas: Optional[List[float]] = None,
628
+ guidance_scale: Optional[float] = 4.0,
629
+ num_images_per_prompt: int = 1,
630
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
631
+ latents: Optional[torch.Tensor] = None,
632
+ prompt_embeds: Optional[torch.Tensor] = None,
633
+ output_type: Optional[str] = "pil",
634
+ return_dict: bool = True,
635
+ attention_kwargs: Optional[Dict[str, Any]] = None,
636
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
637
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
638
+ max_sequence_length: int = 512,
639
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
640
+ ):
641
+ r"""
642
+ Function invoked when calling the pipeline for generation.
643
+
644
+ Args:
645
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
646
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
647
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
648
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
649
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
650
+ latents as `image`, but if passing latents directly it is not encoded again.
651
+ prompt (`str` or `List[str]`, *optional*):
652
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
653
+ instead.
654
+ guidance_scale (`float`, *optional*, defaults to 4.0):
655
+ Guidance scale as defined in [Classifier-Free Diffusion
656
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
657
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
658
+ `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to
659
+ the text `prompt`, usually at the expense of lower image quality.
660
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
661
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
662
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
663
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
664
+ num_inference_steps (`int`, *optional*, defaults to 50):
665
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
666
+ expense of slower inference.
667
+ sigmas (`List[float]`, *optional*):
668
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
669
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
670
+ will be used.
671
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
672
+ The number of images to generate per prompt.
673
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
674
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
675
+ to make generation deterministic.
676
+ latents (`torch.Tensor`, *optional*):
677
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
678
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
679
+ tensor will be generated by sampling using the supplied random `generator`.
680
+ prompt_embeds (`torch.Tensor`, *optional*):
681
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
682
+ provided, text embeddings will be generated from `prompt` input argument.
683
+ output_type (`str`, *optional*, defaults to `"pil"`):
684
+ The output format of the generated image. Choose between
685
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
686
+ return_dict (`bool`, *optional*, defaults to `True`):
687
+ Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
688
+ attention_kwargs (`dict`, *optional*):
689
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
690
+ `self.processor` in
691
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
692
+ callback_on_step_end (`Callable`, *optional*):
693
+ A function that calls at the end of each denoising steps during the inference. The function is called
694
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
695
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
696
+ `callback_on_step_end_tensor_inputs`.
697
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
698
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
699
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
700
+ `._callback_tensor_inputs` attribute of your pipeline class.
701
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
702
+ text_encoder_out_layers (`Tuple[int]`):
703
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
704
+
705
+ Examples:
706
+
707
+ Returns:
708
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
709
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
710
+ generated images.
711
+ """
712
+
713
+ # 1. Check inputs. Raise error if not correct
714
+ self.check_inputs(
715
+ prompt=prompt,
716
+ height=height,
717
+ width=width,
718
+ prompt_embeds=prompt_embeds,
719
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
720
+ )
721
+
722
+ self._guidance_scale = guidance_scale
723
+ self._attention_kwargs = attention_kwargs
724
+ self._current_timestep = None
725
+ self._interrupt = False
726
+
727
+ # 2. Define call parameters
728
+ if prompt is not None and isinstance(prompt, str):
729
+ batch_size = 1
730
+ elif prompt is not None and isinstance(prompt, list):
731
+ batch_size = len(prompt)
732
+ else:
733
+ batch_size = prompt_embeds.shape[0]
734
+
735
+ device = self._execution_device
736
+
737
+ # 3. prepare text embeddings
738
+ prompt_embeds, text_ids = self.encode_prompt(
739
+ prompt=prompt,
740
+ prompt_embeds=prompt_embeds,
741
+ device=device,
742
+ num_images_per_prompt=num_images_per_prompt,
743
+ max_sequence_length=max_sequence_length,
744
+ text_encoder_out_layers=text_encoder_out_layers,
745
+ )
746
+
747
+ # 4. process images
748
+ if image is not None and not isinstance(image, list):
749
+ image = [image]
750
+
751
+ condition_images = None
752
+ if image is not None:
753
+ for img in image:
754
+ self.image_processor.check_image_input(img)
755
+
756
+ condition_images = []
757
+ for img in image:
758
+ image_width, image_height = img.size
759
+ if image_width * image_height > 1024 * 1024:
760
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
761
+ image_width, image_height = img.size
762
+
763
+ multiple_of = self.vae_scale_factor * 2
764
+ image_width = (image_width // multiple_of) * multiple_of
765
+ image_height = (image_height // multiple_of) * multiple_of
766
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
767
+ condition_images.append(img)
768
+ height = height or image_height
769
+ width = width or image_width
770
+
771
+ height = height or self.default_sample_size * self.vae_scale_factor
772
+ width = width or self.default_sample_size * self.vae_scale_factor
773
+
774
+ # 5. prepare latent variables
775
+ num_channels_latents = self.transformer.config.in_channels // 4
776
+ latents, latent_ids = self.prepare_latents(
777
+ batch_size=batch_size * num_images_per_prompt,
778
+ num_latents_channels=num_channels_latents,
779
+ height=height,
780
+ width=width,
781
+ dtype=prompt_embeds.dtype,
782
+ device=device,
783
+ generator=generator,
784
+ latents=latents,
785
+ )
786
+
787
+ image_latents = None
788
+ image_latent_ids = None
789
+ if condition_images is not None:
790
+ image_latents, image_latent_ids = self.prepare_image_latents(
791
+ images=condition_images,
792
+ batch_size=batch_size * num_images_per_prompt,
793
+ generator=generator,
794
+ device=device,
795
+ dtype=self.vae.dtype,
796
+ )
797
+
798
+ # 6. Prepare timesteps
799
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
800
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
801
+ sigmas = None
802
+ image_seq_len = latents.shape[1]
803
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
804
+ timesteps, num_inference_steps = retrieve_timesteps(
805
+ self.scheduler,
806
+ num_inference_steps,
807
+ device,
808
+ sigmas=sigmas,
809
+ mu=mu,
810
+ )
811
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
812
+ self._num_timesteps = len(timesteps)
813
+
814
+ # handle guidance
815
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
816
+ guidance = guidance.expand(latents.shape[0])
817
+
818
+ # 7. Denoising loop
819
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
820
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
821
+ self.scheduler.set_begin_index(0)
822
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
823
+ for i, t in enumerate(timesteps):
824
+ if self.interrupt:
825
+ continue
826
+
827
+ self._current_timestep = t
828
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
829
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
830
+
831
+ latent_model_input = latents.to(self.transformer.dtype)
832
+ latent_image_ids = latent_ids
833
+
834
+ if image_latents is not None:
835
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
836
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
837
+
838
+ noise_pred = self.transformer(
839
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
840
+ timestep=timestep / 1000,
841
+ guidance=guidance,
842
+ encoder_hidden_states=prompt_embeds,
843
+ txt_ids=text_ids, # B, text_seq_len, 4
844
+ img_ids=latent_image_ids, # B, image_seq_len, 4
845
+ joint_attention_kwargs=self._attention_kwargs,
846
+ return_dict=False,
847
+ )[0]
848
+
849
+ noise_pred = noise_pred[:, : latents.size(1)]
850
+
851
+ # compute the previous noisy sample x_t -> x_t-1
852
+ latents_dtype = latents.dtype
853
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
854
+
855
+ if latents.dtype != latents_dtype:
856
+ if torch.backends.mps.is_available():
857
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
858
+ latents = latents.to(latents_dtype)
859
+
860
+ if callback_on_step_end is not None:
861
+ callback_kwargs = {}
862
+ for k in callback_on_step_end_tensor_inputs:
863
+ callback_kwargs[k] = locals()[k]
864
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
865
+
866
+ latents = callback_outputs.pop("latents", latents)
867
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
868
+
869
+ # call the callback, if provided
870
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
871
+ progress_bar.update()
872
+
873
+ if XLA_AVAILABLE:
874
+ xm.mark_step()
875
+
876
+ self._current_timestep = None
877
+
878
+ if output_type == "latent":
879
+ image = latents
880
+ else:
881
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
883
+
884
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
885
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
886
+ latents.device, latents.dtype
887
+ )
888
+ latents = latents * latents_bn_std + latents_bn_mean
889
+ latents = self._unpatchify_latents(latents)
890
+
891
+ image = self.vae.decode(latents, return_dict=False)[0]
892
+ image = self.image_processor.postprocess(image, output_type=output_type)
893
+
894
+ # Offload all models
895
+ self.maybe_free_model_hooks()
896
+
897
+ if not return_dict:
898
+ return (image,)
899
+
900
+ return Flux2PipelineOutput(images=image)
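A note on the decode path above: before `vae.decode`, the packed latents are unpacked via their position ids and de-normalized with the VAE's running batch-norm statistics, the exact inverse of the normalization applied when reference images are encoded. A minimal sketch of that inverse step, with `running_mean`, `running_var`, and `eps` as stand-ins for `self.vae.bn.running_mean`, `self.vae.bn.running_var`, and `self.vae.config.batch_norm_eps`:

```py
import torch

def denormalize_latents(latents: torch.Tensor, running_mean: torch.Tensor,
                        running_var: torch.Tensor, eps: float) -> torch.Tensor:
    # latents: (B, C, H, W) patchified latents in the normalized space used by the transformer
    mean = running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps).to(latents.device, latents.dtype)
    # Undo the (x - mean) / std applied at encoding time before decoding with the VAE.
    return latents * std + mean
```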
videox_fun/pipeline/pipeline_flux2_control.py ADDED
@@ -0,0 +1,973 @@
1
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py
2
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
18
+ replace_example_docstring)
19
+ from diffusers.image_processor import VaeImageProcessor
20
+ from dataclasses import dataclass
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import PIL
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+
33
+ from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor,
34
+ Flux2ControlTransformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor)
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```py
49
+ >>> import torch
50
+ >>> from diffusers import Flux2Pipeline
51
+
52
+ >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
53
+ >>> pipe.to("cuda")
54
+ >>> prompt = "A cat holding a sign that says hello world"
55
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
56
+ >>> # Refer to the pipeline documentation for more details.
57
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
58
+ >>> image.save("flux.png")
59
+ ```
60
+ """
61
+
62
+
63
+ def format_text_input(prompts: List[str], system_message: str = None):
64
+ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
65
+ # when truncation is enabled. The processor counts [IMG] tokens and fails
66
+ # if the count changes after truncation.
67
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
68
+
69
+ return [
70
+ [
71
+ {
72
+ "role": "system",
73
+ "content": [{"type": "text", "text": system_message}],
74
+ },
75
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
76
+ ]
77
+ for prompt in cleaned_txt
78
+ ]
79
+
80
+
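The helper above wraps each prompt in a Pixtral-style system/user chat and strips literal `[IMG]` tokens before tokenization. A minimal sketch of the resulting structure, assuming only the function defined above:

```py
messages = format_text_input(
    prompts=["A red fox[IMG] in the snow"],
    system_message="You are an AI that reasons about image descriptions.",
)
# Each entry is a two-turn conversation: system message first, then the user
# prompt with any "[IMG]" markers removed.
assert messages[0][0]["role"] == "system"
assert messages[0][1]["content"][0]["text"] == "A red fox in the snow"
```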
81
+ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
82
+ a1, b1 = 8.73809524e-05, 1.89833333
83
+ a2, b2 = 0.00016927, 0.45666666
84
+
85
+ if image_seq_len > 4300:
86
+ mu = a2 * image_seq_len + b2
87
+ return float(mu)
88
+
89
+ m_200 = a2 * image_seq_len + b2
90
+ m_10 = a1 * image_seq_len + b1
91
+
92
+ a = (m_200 - m_10) / 190.0
93
+ b = m_200 - 200.0 * a
94
+ mu = a * num_steps + b
95
+
96
+ return float(mu)
97
+
98
+
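`compute_empirical_mu` blends two linear fits in `image_seq_len`: above 4300 tokens it uses the 200-step fit directly, otherwise it interpolates between the 10-step and 200-step fits as a function of `num_steps`. A quick numeric check, assuming only the function above:

```py
# A 1024x1024 image with 8x VAE downsampling and 2x2 packing gives
# (1024 / 16) ** 2 = 4096 tokens, below the 4300 cutoff, so the
# step-count interpolation branch applies.
seq_len = (1024 // 16) ** 2                                  # 4096
mu_50 = compute_empirical_mu(image_seq_len=seq_len, num_steps=50)
mu_10 = compute_empirical_mu(image_seq_len=seq_len, num_steps=10)
# At num_steps=10 the interpolation collapses onto the a1/b1 fit exactly.
assert abs(mu_10 - (8.73809524e-05 * seq_len + 1.89833333)) < 1e-6
print(mu_50, mu_10)
```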
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ r"""
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
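`retrieve_timesteps` simply forwards `timesteps` or `sigmas` to the scheduler when its `set_timesteps` accepts them. A small usage sketch, assuming a recent `diffusers` release where `FlowMatchEulerDiscreteScheduler.set_timesteps` accepts a `sigmas` argument (as this pipeline relies on):

```py
import numpy as np
import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
num_inference_steps = 8
# Same linspace the pipeline builds when the caller does not pass sigmas.
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps, device=torch.device("cpu"), sigmas=sigmas
)
print(timesteps)  # 8 timesteps derived from the custom sigma schedule
```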
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
160
+ def retrieve_latents(
161
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
162
+ ):
163
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
164
+ return encoder_output.latent_dist.sample(generator)
165
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
166
+ return encoder_output.latent_dist.mode()
167
+ elif hasattr(encoder_output, "latents"):
168
+ return encoder_output.latents
169
+ else:
170
+ raise AttributeError("Could not access latents of provided encoder_output")
171
+
172
+
173
+ @dataclass
174
+ class Flux2PipelineOutput(BaseOutput):
175
+ """
176
+ Output class for Flux2 image generation pipelines.
177
+
178
+ Args:
179
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
180
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
181
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
182
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
183
+ passed to the decoder.
184
+ """
185
+
186
+ images: Union[List[PIL.Image.Image], np.ndarray]
187
+
188
+
189
+ class Flux2ControlPipeline(DiffusionPipeline):
190
+ r"""
191
+ The Flux2 control pipeline for controllable text-to-image generation.
192
+
193
+ Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
194
+
195
+ Args:
196
+ transformer ([`Flux2ControlTransformer2DModel`]):
197
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
198
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
199
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
200
+ vae ([`AutoencoderKLFlux2`]):
201
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
202
+ text_encoder ([`Mistral3ForConditionalGeneration`]):
203
+ [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
204
+ tokenizer (`AutoProcessor`):
205
+ Tokenizer of class
206
+ [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
207
+ """
208
+
209
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
210
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
211
+
212
+ def __init__(
213
+ self,
214
+ scheduler: FlowMatchEulerDiscreteScheduler,
215
+ vae: AutoencoderKLFlux2,
216
+ text_encoder: Mistral3ForConditionalGeneration,
217
+ tokenizer: AutoProcessor,
218
+ transformer: Flux2ControlTransformer2DModel,
219
+ ):
220
+ super().__init__()
221
+
222
+ self.register_modules(
223
+ vae=vae,
224
+ text_encoder=text_encoder,
225
+ tokenizer=tokenizer,
226
+ scheduler=scheduler,
227
+ transformer=transformer,
228
+ )
229
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
230
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
231
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
232
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
233
+ self.diffusers_image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
234
+ self.mask_processor = VaeImageProcessor(
235
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
236
+ )
237
+ self.tokenizer_max_length = 512
238
+ self.default_sample_size = 128
239
+
240
+ # fmt: off
241
+ self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
242
+ # fmt: on
243
+
244
+ @staticmethod
245
+ def _get_mistral_3_small_prompt_embeds(
246
+ text_encoder: Mistral3ForConditionalGeneration,
247
+ tokenizer: AutoProcessor,
248
+ prompt: Union[str, List[str]],
249
+ dtype: Optional[torch.dtype] = None,
250
+ device: Optional[torch.device] = None,
251
+ max_sequence_length: int = 512,
252
+ # fmt: off
253
+ system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
254
+ # fmt: on
255
+ hidden_states_layers: List[int] = (10, 20, 30),
256
+ ):
257
+ dtype = text_encoder.dtype if dtype is None else dtype
258
+ device = text_encoder.device if device is None else device
259
+
260
+ prompt = [prompt] if isinstance(prompt, str) else prompt
261
+
262
+ # Format input messages
263
+ messages_batch = format_text_input(prompts=prompt, system_message=system_message)
264
+
265
+ # Process all messages at once
266
+ inputs = tokenizer.apply_chat_template(
267
+ messages_batch,
268
+ add_generation_prompt=False,
269
+ tokenize=True,
270
+ return_dict=True,
271
+ return_tensors="pt",
272
+ padding="max_length",
273
+ truncation=True,
274
+ max_length=max_sequence_length,
275
+ )
276
+
277
+ # Move to device
278
+ input_ids = inputs["input_ids"].to(device)
279
+ attention_mask = inputs["attention_mask"].to(device)
280
+
281
+ # Forward pass through the model
282
+ output = text_encoder(
283
+ input_ids=input_ids,
284
+ attention_mask=attention_mask,
285
+ output_hidden_states=True,
286
+ use_cache=False,
287
+ )
288
+
289
+ # Only use outputs from intermediate layers and stack them
290
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
291
+ out = out.to(dtype=dtype, device=device)
292
+
293
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
294
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
295
+
296
+ return prompt_embeds
297
+
298
+ @staticmethod
299
+ def _prepare_text_ids(
300
+ x: torch.Tensor, # (B, L, D) or (L, D)
301
+ t_coord: Optional[torch.Tensor] = None,
302
+ ):
303
+ B, L, _ = x.shape
304
+ out_ids = []
305
+
306
+ for i in range(B):
307
+ t = torch.arange(1) if t_coord is None else t_coord[i]
308
+ h = torch.arange(1)
309
+ w = torch.arange(1)
310
+ l = torch.arange(L)
311
+
312
+ coords = torch.cartesian_prod(t, h, w, l)
313
+ out_ids.append(coords)
314
+
315
+ return torch.stack(out_ids)
316
+
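`_prepare_text_ids` assigns every text token the 4D coordinate (T=0, H=0, W=0, L=token index). A shape check using only torch and the staticmethod above (assuming this module's `Flux2ControlPipeline` is importable):

```py
import torch

dummy_embeds = torch.zeros(2, 5, 16)                      # (B, L, D)
text_ids = Flux2ControlPipeline._prepare_text_ids(dummy_embeds)
print(text_ids.shape)                                     # torch.Size([2, 5, 4])
# Only the last coordinate varies: it enumerates the token position.
assert torch.equal(text_ids[0, :, 3], torch.arange(5))
```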
317
+ def encode_prompt(
318
+ self,
319
+ prompt: Union[str, List[str]],
320
+ device: Optional[torch.device] = None,
321
+ num_images_per_prompt: int = 1,
322
+ prompt_embeds: Optional[torch.Tensor] = None,
323
+ max_sequence_length: int = 512,
324
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
325
+ ):
326
+ device = device or self._execution_device
327
+
328
+ if prompt is None:
329
+ prompt = ""
330
+
331
+ prompt = [prompt] if isinstance(prompt, str) else prompt
332
+
333
+ if prompt_embeds is None:
334
+ prompt_embeds = self._get_mistral_3_small_prompt_embeds(
335
+ text_encoder=self.text_encoder,
336
+ tokenizer=self.tokenizer,
337
+ prompt=prompt,
338
+ device=device,
339
+ max_sequence_length=max_sequence_length,
340
+ system_message=self.system_message,
341
+ hidden_states_layers=text_encoder_out_layers,
342
+ )
343
+
344
+ batch_size, seq_len, _ = prompt_embeds.shape
345
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
346
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
347
+
348
+ text_ids = self._prepare_text_ids(prompt_embeds)
349
+ text_ids = text_ids.to(device)
350
+ return prompt_embeds, text_ids
351
+
352
+ @staticmethod
353
+ def _prepare_latent_ids(
354
+ latents: torch.Tensor, # (B, C, H, W)
355
+ ):
356
+ r"""
357
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
358
+
359
+ Args:
360
+ latents (torch.Tensor):
361
+ Latent tensor of shape (B, C, H, W)
362
+
363
+ Returns:
364
+ torch.Tensor:
365
+ Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
366
+ H=[0..H-1], W=[0..W-1], L=0
367
+ """
368
+
369
+ batch_size, _, height, width = latents.shape
370
+
371
+ t = torch.arange(1) # [0] - time dimension
372
+ h = torch.arange(height)
373
+ w = torch.arange(width)
374
+ l = torch.arange(1) # [0] - layer dimension
375
+
376
+ # Create position IDs: (H*W, 4)
377
+ latent_ids = torch.cartesian_prod(t, h, w, l)
378
+
379
+ # Expand to batch: (B, H*W, 4)
380
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
381
+
382
+ return latent_ids
383
+
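A small sketch of the coordinate layout produced by `_prepare_latent_ids`, assuming only torch and the staticmethod above:

```py
import torch

latents = torch.zeros(1, 128, 2, 3)                       # (B, C, H, W)
latent_ids = Flux2ControlPipeline._prepare_latent_ids(latents)
print(latent_ids.shape)                                   # torch.Size([1, 6, 4])
# Rows enumerate (T, H, W, L) = (0, h, w, 0) in row-major order.
assert latent_ids[0].tolist() == [
    [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0],
    [0, 1, 0, 0], [0, 1, 1, 0], [0, 1, 2, 0],
]
```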
384
+ @staticmethod
385
+ def _prepare_image_ids(
386
+ image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
387
+ scale: int = 10,
388
+ ):
389
+ r"""
390
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
391
+
392
+ This function creates a unique coordinate for every pixel/patch across all input latents, which may have different
393
+ dimensions.
394
+
395
+ Args:
396
+ image_latents (List[torch.Tensor]):
397
+ A list of image latent feature tensors, typically of shape (C, H, W).
398
+ scale (int, optional):
399
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
400
+ latent is: 'scale + scale * i'. Defaults to 10.
401
+
402
+ Returns:
403
+ torch.Tensor:
404
+ The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
405
+ input latents.
406
+
407
+ Coordinate Components (Dimension 4):
408
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
409
+ - H (Height): The row index within that latent image.
410
+ - W (Width): The column index within that latent image.
411
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
412
+ """
413
+
414
+ if not isinstance(image_latents, list):
415
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
416
+
417
+ # create time offset for each reference image
418
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
419
+ t_coords = [t.view(-1) for t in t_coords]
420
+
421
+ image_latent_ids = []
422
+ for x, t in zip(image_latents, t_coords):
423
+ x = x.squeeze(0)
424
+ _, height, width = x.shape
425
+
426
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
427
+ image_latent_ids.append(x_ids)
428
+
429
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
430
+ image_latent_ids = image_latent_ids.unsqueeze(0)
431
+
432
+ return image_latent_ids
433
+
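`_prepare_image_ids` separates multiple reference latents along the T axis in steps of `scale` (10, 20, ...). A minimal check with two dummy reference latents of different sizes, assuming only torch and the staticmethod above:

```py
import torch

ref_a = torch.zeros(1, 128, 2, 2)                         # 4 spatial positions
ref_b = torch.zeros(1, 128, 1, 3)                         # 3 spatial positions
ids = Flux2ControlPipeline._prepare_image_ids([ref_a, ref_b])
print(ids.shape)                                          # torch.Size([1, 7, 4])
# The T coordinate is 10 for the first reference and 20 for the second.
assert ids[0, :4, 0].tolist() == [10, 10, 10, 10]
assert ids[0, 4:, 0].tolist() == [20, 20, 20]
```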
434
+ @staticmethod
435
+ def _patchify_latents(latents):
436
+ batch_size, num_channels_latents, height, width = latents.shape
437
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
438
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
439
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
440
+ return latents
441
+
442
+ @staticmethod
443
+ def _unpatchify_latents(latents):
444
+ batch_size, num_channels_latents, height, width = latents.shape
445
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
446
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
447
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
448
+ return latents
449
+
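`_patchify_latents` folds each 2x2 spatial patch into the channel dimension and `_unpatchify_latents` is its exact inverse. A round-trip check, assuming only torch and the two staticmethods above:

```py
import torch

x = torch.randn(1, 32, 8, 8)                              # (B, C, H, W)
packed = Flux2ControlPipeline._patchify_latents(x)        # 2x2 patches folded into channels
restored = Flux2ControlPipeline._unpatchify_latents(packed)
assert packed.shape == (1, 128, 4, 4)
assert torch.equal(restored, x)                           # lossless round trip
```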
450
+ @staticmethod
451
+ def _pack_latents(latents):
452
+ """
453
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
454
+ """
455
+
456
+ batch_size, num_channels, height, width = latents.shape
457
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
458
+
459
+ return latents
460
+
461
+ @staticmethod
462
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
463
+ """
464
+ using position ids to scatter tokens into place
465
+ """
466
+ x_list = []
467
+ for data, pos in zip(x, x_ids):
468
+ _, ch = data.shape # noqa: F841
469
+ h_ids = pos[:, 1].to(torch.int64)
470
+ w_ids = pos[:, 2].to(torch.int64)
471
+
472
+ h = torch.max(h_ids) + 1
473
+ w = torch.max(w_ids) + 1
474
+
475
+ flat_ids = h_ids * w + w_ids
476
+
477
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
478
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
479
+
480
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
481
+
482
+ out = out.view(h, w, ch).permute(2, 0, 1)
483
+ x_list.append(out)
484
+
485
+ return torch.stack(x_list, dim=0)
486
+
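Packing flattens the spatial grid into a token sequence; `_unpack_latents_with_ids` scatters those tokens back using the H/W columns of the position ids. A round-trip check, assuming only torch and the staticmethods above:

```py
import torch

latents = torch.randn(1, 128, 4, 4)                       # (B, C, H, W)
latent_ids = Flux2ControlPipeline._prepare_latent_ids(latents)
packed = Flux2ControlPipeline._pack_latents(latents)      # (1, 16, 128): one token per spatial position
unpacked = Flux2ControlPipeline._unpack_latents_with_ids(packed, latent_ids)
assert unpacked.shape == (1, 128, 4, 4)
assert torch.equal(unpacked, latents)                     # ids scatter every token back into place
```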
487
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
488
+ if image.ndim != 4:
489
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
490
+
491
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
492
+ image_latents = self._patchify_latents(image_latents)
493
+
494
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
495
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
496
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
497
+
498
+ return image_latents
499
+
500
+ def prepare_latents(
501
+ self,
502
+ batch_size,
503
+ num_latents_channels,
504
+ height,
505
+ width,
506
+ dtype,
507
+ device,
508
+ generator: torch.Generator,
509
+ latents: Optional[torch.Tensor] = None,
510
+ ):
511
+ # VAE applies 8x compression on images but we must also account for packing which requires
512
+ # latent height and width to be divisible by 2.
513
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
514
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
515
+
516
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
517
+ if isinstance(generator, list) and len(generator) != batch_size:
518
+ raise ValueError(
519
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
520
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
521
+ )
522
+ if latents is None:
523
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
524
+ else:
525
+ latents = latents.to(device=device, dtype=dtype)
526
+
527
+ latent_ids = self._prepare_latent_ids(latents)
528
+ latent_ids = latent_ids.to(device)
529
+
530
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
531
+ return latents, latent_ids
532
+
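For reference, the shape arithmetic behind `prepare_latents` for a 1024x1024 request (a sketch, assuming `vae_scale_factor = 8` and the 2x2 packing configured in `__init__`):

```py
height = width = 1024
vae_scale_factor = 8
# 1024 / 8 = 128 latent rows, folded 2x2 -> a 64x64 token grid.
tokens_per_side = height // (vae_scale_factor * 2)        # 64
image_seq_len = tokens_per_side * (width // (vae_scale_factor * 2))
print(image_seq_len)                                      # 4096 tokens per image
# prepare_latents returns latents of shape (B, 4096, num_latents_channels * 4)
# and latent_ids of shape (B, 4096, 4).
```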
533
+ def prepare_image_latents(
534
+ self,
535
+ images: List[torch.Tensor],
536
+ batch_size,
537
+ generator: torch.Generator,
538
+ device,
539
+ dtype,
540
+ ):
541
+ image_latents = []
542
+ for image in images:
543
+ image = image.to(device=device, dtype=dtype)
544
+ image_latent = self._encode_vae_image(image=image, generator=generator)
545
+ image_latents.append(image_latent) # (1, 128, 32, 32)
546
+
547
+ image_latent_ids = self._prepare_image_ids(image_latents)
548
+
549
+ # Pack each latent and concatenate
550
+ packed_latents = []
551
+ for latent in image_latents:
552
+ # latent: (1, 128, 32, 32)
553
+ packed = self._pack_latents(latent) # (1, 1024, 128)
554
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
555
+ packed_latents.append(packed)
556
+
557
+ # Concatenate all reference tokens along sequence dimension
558
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
559
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
560
+
561
+ image_latents = image_latents.repeat(batch_size, 1, 1)
562
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
563
+ image_latent_ids = image_latent_ids.to(device)
564
+
565
+ return image_latents, image_latent_ids
566
+
567
+ def check_inputs(
568
+ self,
569
+ prompt,
570
+ height,
571
+ width,
572
+ prompt_embeds=None,
573
+ callback_on_step_end_tensor_inputs=None,
574
+ ):
575
+ if (
576
+ height is not None
577
+ and height % (self.vae_scale_factor * 2) != 0
578
+ or width is not None
579
+ and width % (self.vae_scale_factor * 2) != 0
580
+ ):
581
+ logger.warning(
582
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
583
+ )
584
+
585
+ if callback_on_step_end_tensor_inputs is not None and not all(
586
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
587
+ ):
588
+ raise ValueError(
589
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
590
+ )
591
+
592
+ if prompt is not None and prompt_embeds is not None:
593
+ raise ValueError(
594
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
595
+ " only forward one of the two."
596
+ )
597
+ elif prompt is None and prompt_embeds is None:
598
+ raise ValueError(
599
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
600
+ )
601
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
602
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
603
+
604
+ @property
605
+ def guidance_scale(self):
606
+ return self._guidance_scale
607
+
608
+ @property
609
+ def joint_attention_kwargs(self):
610
+ return self._joint_attention_kwargs
611
+
612
+ @property
613
+ def num_timesteps(self):
614
+ return self._num_timesteps
615
+
616
+ @property
617
+ def current_timestep(self):
618
+ return self._current_timestep
619
+
620
+ @property
621
+ def interrupt(self):
622
+ return self._interrupt
623
+
624
+ @torch.no_grad()
625
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
626
+ def __call__(
627
+ self,
628
+ prompt: Union[str, List[str]] = None,
629
+ height: Optional[int] = None,
630
+ width: Optional[int] = None,
631
+
632
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
633
+ inpaint_image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
634
+ mask_image: Union[torch.FloatTensor] = None,
635
+ control_image: Union[torch.FloatTensor] = None,
636
+ control_context_scale: float = 1.0,
637
+
638
+ num_inference_steps: int = 50,
639
+ sigmas: Optional[List[float]] = None,
640
+ guidance_scale: Optional[float] = 4.0,
641
+ num_images_per_prompt: int = 1,
642
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
643
+ latents: Optional[torch.Tensor] = None,
644
+ prompt_embeds: Optional[torch.Tensor] = None,
645
+ output_type: Optional[str] = "pil",
646
+ return_dict: bool = True,
647
+ attention_kwargs: Optional[Dict[str, Any]] = None,
648
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
649
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
650
+ max_sequence_length: int = 512,
651
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
652
+ ):
653
+ r"""
654
+ Function invoked when calling the pipeline for generation.
655
+
656
+ Args:
657
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
658
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
659
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
660
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
661
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
662
+ latents as `image`, but if passing latents directly it is not encoded again.
663
+ prompt (`str` or `List[str]`, *optional*):
664
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
665
+ instead.
666
+ guidance_scale (`float`, *optional*, defaults to 4.0):
667
+ Guidance scale as defined in [Classifier-Free Diffusion
668
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
669
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
670
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
671
+ the text `prompt`, usually at the expense of lower image quality.
672
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
673
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
674
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
675
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
676
+ num_inference_steps (`int`, *optional*, defaults to 50):
677
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
678
+ expense of slower inference.
679
+ sigmas (`List[float]`, *optional*):
680
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
681
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
682
+ will be used.
683
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
684
+ The number of images to generate per prompt.
685
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
686
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
687
+ to make generation deterministic.
688
+ latents (`torch.Tensor`, *optional*):
689
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
690
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
691
+ tensor will be generated by sampling using the supplied random `generator`.
692
+ prompt_embeds (`torch.Tensor`, *optional*):
693
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
694
+ provided, text embeddings will be generated from `prompt` input argument.
695
+ output_type (`str`, *optional*, defaults to `"pil"`):
696
+ The output format of the generated image. Choose between
697
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
698
+ return_dict (`bool`, *optional*, defaults to `True`):
699
+ Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
700
+ attention_kwargs (`dict`, *optional*):
701
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
702
+ `self.processor` in
703
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
704
+ callback_on_step_end (`Callable`, *optional*):
705
+ A function called at the end of each denoising step during inference. The function is called
706
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
707
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
708
+ `callback_on_step_end_tensor_inputs`.
709
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
710
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
711
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
712
+ `._callback_tensor_inputs` attribute of your pipeline class.
713
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
714
+ text_encoder_out_layers (`Tuple[int]`):
715
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
716
+
717
+ Examples:
718
+
719
+ Returns:
720
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
721
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
722
+ generated images.
723
+ """
724
+
725
+ # 1. Check inputs. Raise error if not correct
726
+ self.check_inputs(
727
+ prompt=prompt,
728
+ height=height,
729
+ width=width,
730
+ prompt_embeds=prompt_embeds,
731
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
732
+ )
733
+
734
+ self._guidance_scale = guidance_scale
735
+ self._attention_kwargs = attention_kwargs
736
+ self._current_timestep = None
737
+ self._interrupt = False
738
+
739
+ # 2. Define call parameters
740
+ if prompt is not None and isinstance(prompt, str):
741
+ batch_size = 1
742
+ elif prompt is not None and isinstance(prompt, list):
743
+ batch_size = len(prompt)
744
+ else:
745
+ batch_size = prompt_embeds.shape[0]
746
+
747
+ device = self._execution_device
748
+ weight_dtype = self.text_encoder.dtype
749
+
750
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device, weight_dtype)
751
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
752
+ device, weight_dtype
753
+ )
754
+ height = height or self.default_sample_size * self.vae_scale_factor
755
+ width = width or self.default_sample_size * self.vae_scale_factor
756
+ num_channels_latents = self.transformer.config.in_channels // 4
757
+
758
+ # Prepare mask latent variables
759
+ if mask_image is not None:
760
+ mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
761
+ mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(dtype=weight_dtype, device=device)
762
+
763
+ if inpaint_image is not None:
764
+ init_image = self.diffusers_image_processor.preprocess(inpaint_image, height=height, width=width)
765
+ init_image = init_image.to(dtype=weight_dtype, device=device) * (mask_condition < 0.5)
766
+ inpaint_latent = self.vae.encode(init_image)[0].mode()
767
+ else:
768
+ inpaint_latent = torch.zeros((batch_size, num_channels_latents * 4, height // 2 // self.vae_scale_factor, width // 2 // self.vae_scale_factor)).to(device, weight_dtype)
769
+
770
+ if control_image is not None:
771
+ control_image = self.diffusers_image_processor.preprocess(control_image, height=height, width=width)
772
+ control_image = control_image.to(dtype=weight_dtype, device=device)
773
+ control_latents = self.vae.encode(control_image)[0].mode()
774
+ else:
775
+ control_latents = torch.zeros_like(inpaint_latent)
776
+
777
+ mask_condition = F.interpolate(1 - mask_condition[:, :1], size=control_latents.size()[-2:], mode='nearest').to(device, weight_dtype)
778
+ mask_condition = self._patchify_latents(mask_condition)
779
+ mask_condition = self._pack_latents(mask_condition)
780
+
781
+ if inpaint_image is not None:
782
+ inpaint_latent = self._patchify_latents(inpaint_latent)
783
+ inpaint_latent = (inpaint_latent - latents_bn_mean) / latents_bn_std
784
+ inpaint_latent = self._pack_latents(inpaint_latent)
785
+ else:
786
+ inpaint_latent = self._patchify_latents(inpaint_latent)
787
+ inpaint_latent = self._pack_latents(inpaint_latent)
788
+
789
+ if control_image is not None:
790
+ control_latents = self._patchify_latents(control_latents)
791
+ control_latents = (control_latents - latents_bn_mean) / latents_bn_std
792
+ control_latents = self._pack_latents(control_latents)
793
+ else:
794
+ control_latents = self._patchify_latents(control_latents)
795
+ control_latents = self._pack_latents(control_latents)
796
+ control_context = torch.concat([control_latents, mask_condition, inpaint_latent], dim=2)
797
+
798
+ # 3. prepare text embeddings
799
+ prompt_embeds, text_ids = self.encode_prompt(
800
+ prompt=prompt,
801
+ prompt_embeds=prompt_embeds,
802
+ device=device,
803
+ num_images_per_prompt=num_images_per_prompt,
804
+ max_sequence_length=max_sequence_length,
805
+ text_encoder_out_layers=text_encoder_out_layers,
806
+ )
807
+
808
+ # 4. process images
809
+ if image is not None and not isinstance(image, list):
810
+ image = [image]
811
+
812
+ condition_images = None
813
+ if image is not None:
814
+ for img in image:
815
+ self.image_processor.check_image_input(img)
816
+
817
+ condition_images = []
818
+ for img in image:
819
+ image_width, image_height = img.size
820
+ if image_width * image_height > 1024 * 1024:
821
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
822
+ image_width, image_height = img.size
823
+
824
+ multiple_of = self.vae_scale_factor * 2
825
+ image_width = (image_width // multiple_of) * multiple_of
826
+ image_height = (image_height // multiple_of) * multiple_of
827
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
828
+ condition_images.append(img)
829
+ height = height or image_height
830
+ width = width or image_width
831
+
832
+ # 5. prepare latent variables
833
+ latents, latent_ids = self.prepare_latents(
834
+ batch_size=batch_size * num_images_per_prompt,
835
+ num_latents_channels=num_channels_latents,
836
+ height=height,
837
+ width=width,
838
+ dtype=prompt_embeds.dtype,
839
+ device=device,
840
+ generator=generator,
841
+ latents=latents,
842
+ )
843
+
844
+ image_latents = None
845
+ image_latent_ids = None
846
+ if condition_images is not None:
847
+ image_latents, image_latent_ids = self.prepare_image_latents(
848
+ images=condition_images,
849
+ batch_size=batch_size * num_images_per_prompt,
850
+ generator=generator,
851
+ device=device,
852
+ dtype=self.vae.dtype,
853
+ )
854
+
855
+ # 6. Prepare timesteps
856
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
857
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
858
+ sigmas = None
859
+ image_seq_len = latents.shape[1]
860
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
861
+ timesteps, num_inference_steps = retrieve_timesteps(
862
+ self.scheduler,
863
+ num_inference_steps,
864
+ device,
865
+ sigmas=sigmas,
866
+ mu=mu,
867
+ )
868
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
869
+ self._num_timesteps = len(timesteps)
870
+
871
+ # handle guidance
872
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
873
+ guidance = guidance.expand(latents.shape[0])
874
+
875
+ # 7. Denoising loop
876
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
877
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
878
+ self.scheduler.set_begin_index(0)
879
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
880
+ for i, t in enumerate(timesteps):
881
+ if self.interrupt:
882
+ continue
883
+
884
+ self._current_timestep = t
885
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
886
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
887
+
888
+ latent_model_input = latents.to(self.transformer.dtype)
889
+ control_context_input = control_context.to(self.transformer.dtype)
890
+ latent_image_ids = latent_ids
891
+
892
+ if image_latents is not None:
893
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
894
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
895
+
896
+ local_bs, local_length, local_c = control_context.size()
897
+ control_context_input = torch.cat(
898
+ [
899
+ control_context,
900
+ torch.zeros(
901
+ [
902
+ local_bs,
903
+ image_latents.size()[1],
904
+ local_c
905
+ ]
906
+ ).to(control_context.device, control_context.dtype)],
907
+ dim=1
908
+ ).to(self.transformer.dtype)
909
+
910
+ noise_pred = self.transformer(
911
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
912
+ timestep=timestep / 1000,
913
+ guidance=guidance,
914
+ encoder_hidden_states=prompt_embeds,
915
+ txt_ids=text_ids, # B, text_seq_len, 4
916
+ img_ids=latent_image_ids, # B, image_seq_len, 4
917
+ joint_attention_kwargs=self._attention_kwargs,
918
+ control_context=control_context_input,
919
+ control_context_scale=control_context_scale,
920
+ return_dict=False,
921
+ )[0]
922
+
923
+ noise_pred = noise_pred[:, : latents.size(1)]
924
+
925
+ # compute the previous noisy sample x_t -> x_t-1
926
+ latents_dtype = latents.dtype
927
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
928
+
929
+ if latents.dtype != latents_dtype:
930
+ if torch.backends.mps.is_available():
931
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
932
+ latents = latents.to(latents_dtype)
933
+
934
+ if callback_on_step_end is not None:
935
+ callback_kwargs = {}
936
+ for k in callback_on_step_end_tensor_inputs:
937
+ callback_kwargs[k] = locals()[k]
938
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
939
+
940
+ latents = callback_outputs.pop("latents", latents)
941
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
942
+
943
+ # call the callback, if provided
944
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
945
+ progress_bar.update()
946
+
947
+ if XLA_AVAILABLE:
948
+ xm.mark_step()
949
+
950
+ self._current_timestep = None
951
+
952
+ if output_type == "latent":
953
+ image = latents
954
+ else:
955
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
956
+
957
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
958
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
959
+ latents.device, latents.dtype
960
+ )
961
+ latents = latents * latents_bn_std + latents_bn_mean
962
+ latents = self._unpatchify_latents(latents)
963
+
964
+ image = self.vae.decode(latents, return_dict=False)[0]
965
+ image = self.image_processor.postprocess(image, output_type=output_type)
966
+
967
+ # Offload all models
968
+ self.maybe_free_model_hooks()
969
+
970
+ if not return_dict:
971
+ return (image,)
972
+
973
+ return Flux2PipelineOutput(images=image)
videox_fun/pipeline/pipeline_hunyuanvideo.py ADDED
@@ -0,0 +1,805 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
2
+ # Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
23
+ from diffusers.loaders import HunyuanVideoLoraLoaderMixin
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
26
+ from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available,
27
+ logging, replace_example_docstring)
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.video_processor import VideoProcessor
30
+
31
+ from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor,
32
+ CLIPTextModel, CLIPTokenizer,
33
+ HunyuanVideoTransformer3DModel, LlamaModel,
34
+ LlamaTokenizerFast, LlavaForConditionalGeneration)
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```python
49
+ >>> import torch
50
+ >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
51
+ >>> from diffusers.utils import export_to_video
52
+
53
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo"
54
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
55
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
56
+ ... )
57
+ >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
58
+ >>> pipe.vae.enable_tiling()
59
+ >>> pipe.to("cuda")
60
+
61
+ >>> output = pipe(
62
+ ... prompt="A cat walks on the grass, realistic",
63
+ ... height=320,
64
+ ... width=512,
65
+ ... num_frames=61,
66
+ ... num_inference_steps=30,
67
+ ... ).frames[0]
68
+ >>> export_to_video(output, "output.mp4", fps=15)
69
+ ```
70
+ """
71
+
72
+
73
+ DEFAULT_PROMPT_TEMPLATE = {
74
+ "template": (
75
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
76
+ "1. The main content and theme of the video."
77
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
78
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
79
+ "4. background environment, light, style and atmosphere."
80
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
81
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
82
+ ),
83
+ "crop_start": 95,
84
+ }
85
+
86
+
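`crop_start` records how many template tokens precede the user prompt; in the upstream HunyuanVideo prompt-encoding step that prefix is sliced off the LLM hidden states so only the user text conditions the transformer. A sketch of that cropping with placeholder tensors (`hidden_states` and `attention_mask` are illustrative, not part of this file):

```py
import torch

crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]         # 95 template tokens
hidden_states = torch.randn(1, 256, 4096)                  # (B, seq_len, dim) placeholder LLM output
attention_mask = torch.ones(1, 256, dtype=torch.long)

# Drop the instruction prefix so conditioning starts at the user prompt.
prompt_embeds = hidden_states[:, crop_start:]
prompt_attention_mask = attention_mask[:, crop_start:]
print(prompt_embeds.shape)                                 # torch.Size([1, 161, 4096])
```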
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
145
+
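+ # Usage sketch (editor's illustration, mirroring the call made later in `__call__`):
+ #
+ #     sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]
+ #     timesteps, num_inference_steps = retrieve_timesteps(
+ #         scheduler, num_inference_steps, device, sigmas=sigmas
+ #     )
+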
146
+
147
+ @dataclass
148
+ class HunyuanVideoPipelineOutput(BaseOutput):
149
+ r"""
150
+ Output class for video pipelines.
151
+
152
+ Args:
153
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
154
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
155
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
156
+ `(batch_size, num_frames, channels, height, width)`.
157
+ """
158
+
159
+ videos: torch.Tensor
160
+
161
+
162
+ class HunyuanVideoPipeline(DiffusionPipeline):
163
+ r"""
164
+ Pipeline for text-to-video generation using HunyuanVideo.
165
+
166
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
167
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
168
+
169
+ Args:
170
+ text_encoder ([`LlamaModel`]):
171
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
172
+ tokenizer (`LlamaTokenizer`):
173
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
174
+ transformer ([`HunyuanVideoTransformer3DModel`]):
175
+ Conditional Transformer to denoise the encoded image latents.
176
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
177
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
178
+ vae ([`AutoencoderKLHunyuanVideo`]):
179
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
180
+ text_encoder_2 ([`CLIPTextModel`]):
181
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
182
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
183
+ tokenizer_2 (`CLIPTokenizer`):
184
+ Tokenizer of class
185
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
186
+ """
187
+
188
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
189
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
190
+
191
+ def __init__(
192
+ self,
193
+ text_encoder: LlamaModel,
194
+ tokenizer: LlamaTokenizerFast,
195
+ transformer: HunyuanVideoTransformer3DModel,
196
+ vae: AutoencoderKLHunyuanVideo,
197
+ scheduler: FlowMatchEulerDiscreteScheduler,
198
+ text_encoder_2: CLIPTextModel,
199
+ tokenizer_2: CLIPTokenizer,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.register_modules(
204
+ vae=vae,
205
+ text_encoder=text_encoder,
206
+ tokenizer=tokenizer,
207
+ transformer=transformer,
208
+ scheduler=scheduler,
209
+ text_encoder_2=text_encoder_2,
210
+ tokenizer_2=tokenizer_2,
211
+ )
212
+
213
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
214
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
215
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
216
+
217
+ def _get_llama_prompt_embeds(
218
+ self,
219
+ prompt: Union[str, List[str]],
220
+ prompt_template: Dict[str, Any],
221
+ num_videos_per_prompt: int = 1,
222
+ device: Optional[torch.device] = None,
223
+ dtype: Optional[torch.dtype] = None,
224
+ max_sequence_length: int = 256,
225
+ num_hidden_layers_to_skip: int = 2,
226
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
227
+ device = device or self._execution_device
228
+ dtype = dtype or self.text_encoder.dtype
229
+
230
+ prompt = [prompt] if isinstance(prompt, str) else prompt
231
+ batch_size = len(prompt)
232
+
233
+ prompt = [prompt_template["template"].format(p) for p in prompt]
234
+
235
+ crop_start = prompt_template.get("crop_start", None)
236
+ if crop_start is None:
237
+ prompt_template_input = self.tokenizer(
238
+ prompt_template["template"],
239
+ padding="max_length",
240
+ return_tensors="pt",
241
+ return_length=False,
242
+ return_overflowing_tokens=False,
243
+ return_attention_mask=False,
244
+ )
245
+ crop_start = prompt_template_input["input_ids"].shape[-1]
246
+ # Remove <|eot_id|> token and placeholder {}
247
+ crop_start -= 2
248
+
249
+ max_sequence_length += crop_start
250
+ text_inputs = self.tokenizer(
251
+ prompt,
252
+ max_length=max_sequence_length,
253
+ padding="max_length",
254
+ truncation=True,
255
+ return_tensors="pt",
256
+ return_length=False,
257
+ return_overflowing_tokens=False,
258
+ return_attention_mask=True,
259
+ )
260
+ text_input_ids = text_inputs.input_ids.to(device=device)
261
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
262
+
263
+ prompt_embeds = self.text_encoder(
264
+ input_ids=text_input_ids,
265
+ attention_mask=prompt_attention_mask,
266
+ output_hidden_states=True,
267
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
268
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
269
+
270
+ if crop_start is not None and crop_start > 0:
271
+ prompt_embeds = prompt_embeds[:, crop_start:]
272
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
273
+
274
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
275
+ _, seq_len, _ = prompt_embeds.shape
276
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
277
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
278
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
279
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
280
+
281
+ return prompt_embeds, prompt_attention_mask
282
+
283
+ def _get_clip_prompt_embeds(
284
+ self,
285
+ prompt: Union[str, List[str]],
286
+ num_videos_per_prompt: int = 1,
287
+ device: Optional[torch.device] = None,
288
+ dtype: Optional[torch.dtype] = None,
289
+ max_sequence_length: int = 77,
290
+ ) -> torch.Tensor:
291
+ device = device or self._execution_device
292
+ dtype = dtype or self.text_encoder_2.dtype
293
+
294
+ prompt = [prompt] if isinstance(prompt, str) else prompt
295
+ batch_size = len(prompt)
296
+
297
+ text_inputs = self.tokenizer_2(
298
+ prompt,
299
+ padding="max_length",
300
+ max_length=max_sequence_length,
301
+ truncation=True,
302
+ return_tensors="pt",
303
+ )
304
+
305
+ text_input_ids = text_inputs.input_ids
306
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
307
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
308
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
309
+ logger.warning(
310
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
311
+ f" {max_sequence_length} tokens: {removed_text}"
312
+ )
313
+
314
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
315
+
316
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
317
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
318
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
319
+
320
+ return prompt_embeds
321
+
322
+ def encode_prompt(
323
+ self,
324
+ prompt: Union[str, List[str]],
325
+ prompt_2: Union[str, List[str]] = None,
326
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
327
+ num_videos_per_prompt: int = 1,
328
+ prompt_embeds: Optional[torch.Tensor] = None,
329
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
330
+ prompt_attention_mask: Optional[torch.Tensor] = None,
331
+ device: Optional[torch.device] = None,
332
+ dtype: Optional[torch.dtype] = None,
333
+ max_sequence_length: int = 256,
334
+ ):
335
+ if prompt_embeds is None:
336
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
337
+ prompt,
338
+ prompt_template,
339
+ num_videos_per_prompt,
340
+ device=device,
341
+ dtype=dtype,
342
+ max_sequence_length=max_sequence_length,
343
+ )
344
+
345
+ if pooled_prompt_embeds is None:
346
+ if prompt_2 is None:
347
+ prompt_2 = prompt
348
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
349
+ prompt,
350
+ num_videos_per_prompt,
351
+ device=device,
352
+ dtype=dtype,
353
+ max_sequence_length=77,
354
+ )
355
+
356
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
357
+
358
+ def check_inputs(
359
+ self,
360
+ prompt,
361
+ prompt_2,
362
+ height,
363
+ width,
364
+ prompt_embeds=None,
365
+ callback_on_step_end_tensor_inputs=None,
366
+ prompt_template=None,
367
+ ):
368
+ if height % 16 != 0 or width % 16 != 0:
369
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
370
+
371
+ if callback_on_step_end_tensor_inputs is not None and not all(
372
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
373
+ ):
374
+ raise ValueError(
375
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
376
+ )
377
+
378
+ if prompt is not None and prompt_embeds is not None:
379
+ raise ValueError(
380
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
381
+ " only forward one of the two."
382
+ )
383
+ elif prompt_2 is not None and prompt_embeds is not None:
384
+ raise ValueError(
385
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
386
+ " only forward one of the two."
387
+ )
388
+ elif prompt is None and prompt_embeds is None:
389
+ raise ValueError(
390
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
391
+ )
392
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
393
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
394
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
395
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
396
+
397
+ if prompt_template is not None:
398
+ if not isinstance(prompt_template, dict):
399
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
400
+ if "template" not in prompt_template:
401
+ raise ValueError(
402
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
403
+ )
404
+
405
+ def prepare_latents(
406
+ self,
407
+ batch_size: int,
408
+ num_channels_latents: int = 32,
409
+ height: int = 720,
410
+ width: int = 1280,
411
+ num_frames: int = 129,
412
+ dtype: Optional[torch.dtype] = None,
413
+ device: Optional[torch.device] = None,
414
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
415
+ latents: Optional[torch.Tensor] = None,
416
+ ) -> torch.Tensor:
417
+ if latents is not None:
418
+ return latents.to(device=device, dtype=dtype)
419
+
420
+ shape = (
421
+ batch_size,
422
+ num_channels_latents,
423
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
424
+ int(height) // self.vae_scale_factor_spatial,
425
+ int(width) // self.vae_scale_factor_spatial,
426
+ )
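+ # For the defaults (height=720, width=1280, num_frames=129) and the compression ratios set in
+ # `__init__` (temporal 4, spatial 8), this is a latent grid of 33 x 90 x 160 (frames x height x width).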
427
+ if isinstance(generator, list) and len(generator) != batch_size:
428
+ raise ValueError(
429
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
430
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
431
+ )
432
+
433
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
434
+ return latents
435
+
436
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
437
+ latents = 1 / self.vae.config.scaling_factor * latents
438
+
439
+ frames = self.vae.decode(latents).sample
440
+ frames = (frames / 2 + 0.5).clamp(0, 1)
441
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
442
+ frames = frames.cpu().float().numpy()
443
+ return frames
444
+
445
+ def enable_vae_slicing(self):
446
+ r"""
447
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
448
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
449
+ """
450
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
451
+ deprecate(
452
+ "enable_vae_slicing",
453
+ "0.40.0",
454
+ depr_message,
455
+ )
456
+ self.vae.enable_slicing()
457
+
458
+ def disable_vae_slicing(self):
459
+ r"""
460
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
461
+ computing decoding in one step.
462
+ """
463
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
464
+ deprecate(
465
+ "disable_vae_slicing",
466
+ "0.40.0",
467
+ depr_message,
468
+ )
469
+ self.vae.disable_slicing()
470
+
471
+ def enable_vae_tiling(self):
472
+ r"""
473
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
474
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
475
+ processing larger images.
476
+ """
477
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
478
+ deprecate(
479
+ "enable_vae_tiling",
480
+ "0.40.0",
481
+ depr_message,
482
+ )
483
+ self.vae.enable_tiling()
484
+
485
+ def disable_vae_tiling(self):
486
+ r"""
487
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
488
+ computing decoding in one step.
489
+ """
490
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
491
+ deprecate(
492
+ "disable_vae_tiling",
493
+ "0.40.0",
494
+ depr_message,
495
+ )
496
+ self.vae.disable_tiling()
497
+
498
+ @property
499
+ def guidance_scale(self):
500
+ return self._guidance_scale
501
+
502
+ @property
503
+ def num_timesteps(self):
504
+ return self._num_timesteps
505
+
506
+ @property
507
+ def attention_kwargs(self):
508
+ return self._attention_kwargs
509
+
510
+ @property
511
+ def current_timestep(self):
512
+ return self._current_timestep
513
+
514
+ @property
515
+ def interrupt(self):
516
+ return self._interrupt
517
+
518
+ @torch.no_grad()
519
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
520
+ def __call__(
521
+ self,
522
+ prompt: Union[str, List[str]] = None,
523
+ prompt_2: Union[str, List[str]] = None,
524
+ negative_prompt: Union[str, List[str]] = None,
525
+ negative_prompt_2: Union[str, List[str]] = None,
526
+ height: int = 720,
527
+ width: int = 1280,
528
+ num_frames: int = 129,
529
+ num_inference_steps: int = 50,
530
+ sigmas: List[float] = None,
531
+ true_cfg_scale: float = 1.0,
532
+ guidance_scale: float = 6.0,
533
+ num_videos_per_prompt: Optional[int] = 1,
534
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
535
+ latents: Optional[torch.Tensor] = None,
536
+ prompt_embeds: Optional[torch.Tensor] = None,
537
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
538
+ prompt_attention_mask: Optional[torch.Tensor] = None,
539
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
540
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
541
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
542
+ output_type: str = "numpy",
543
+ return_dict: bool = False,
544
+ attention_kwargs: Optional[Dict[str, Any]] = None,
545
+ callback_on_step_end: Optional[
546
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
547
+ ] = None,
548
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
549
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
550
+ max_sequence_length: int = 256,
551
+ ):
552
+ r"""
553
+ The call function to the pipeline for generation.
554
+
555
+ Args:
556
+ prompt (`str` or `List[str]`, *optional*):
557
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
558
+ instead.
559
+ prompt_2 (`str` or `List[str]`, *optional*):
560
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
561
+ will be used instead.
562
+ negative_prompt (`str` or `List[str]`, *optional*):
563
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
564
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
565
+ not greater than `1`).
566
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
567
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
568
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
569
+ height (`int`, defaults to `720`):
570
+ The height in pixels of the generated image.
571
+ width (`int`, defaults to `1280`):
572
+ The width in pixels of the generated image.
573
+ num_frames (`int`, defaults to `129`):
574
+ The number of frames in the generated video.
575
+ num_inference_steps (`int`, defaults to `50`):
576
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
577
+ expense of slower inference.
578
+ sigmas (`List[float]`, *optional*):
579
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
580
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
581
+ will be used.
582
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
583
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
584
+ `negative_prompt` is provided.
585
+ guidance_scale (`float`, defaults to `6.0`):
586
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
587
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
588
+
589
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
590
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
591
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
592
+ The number of images to generate per prompt.
593
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
594
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
595
+ generation deterministic.
596
+ latents (`torch.Tensor`, *optional*):
597
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
598
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
599
+ tensor is generated by sampling using the supplied random `generator`.
600
+ prompt_embeds (`torch.Tensor`, *optional*):
601
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
602
+ provided, text embeddings are generated from the `prompt` input argument.
603
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
604
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
605
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
606
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
607
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
608
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
609
+ argument.
610
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
611
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
612
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
613
+ input argument.
614
+ output_type (`str`, *optional*, defaults to `"numpy"`):
615
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
616
+ return_dict (`bool`, *optional*, defaults to `False`):
617
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
618
+ attention_kwargs (`dict`, *optional*):
619
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
620
+ `self.processor` in
621
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
622
+ clip_skip (`int`, *optional*):
623
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
624
+ the output of the pre-final layer will be used for computing the prompt embeddings.
625
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
626
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
627
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
628
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
629
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
630
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
631
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
632
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
633
+ `._callback_tensor_inputs` attribute of your pipeline class.
634
+
635
+ Examples:
636
+
637
+ Returns:
638
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
639
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
640
+ where the first element is a list with the generated videos.
642
+ """
643
+
644
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
645
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
646
+
647
+ # 1. Check inputs. Raise error if not correct
648
+ self.check_inputs(
649
+ prompt,
650
+ prompt_2,
651
+ height,
652
+ width,
653
+ prompt_embeds,
654
+ callback_on_step_end_tensor_inputs,
655
+ prompt_template,
656
+ )
657
+
658
+ has_neg_prompt = negative_prompt is not None or (
659
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
660
+ )
661
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
662
+
663
+ self._guidance_scale = guidance_scale
664
+ self._attention_kwargs = attention_kwargs
665
+ self._current_timestep = None
666
+ self._interrupt = False
667
+
668
+ device = self._execution_device
669
+
670
+ # 2. Define call parameters
671
+ if prompt is not None and isinstance(prompt, str):
672
+ batch_size = 1
673
+ elif prompt is not None and isinstance(prompt, list):
674
+ batch_size = len(prompt)
675
+ else:
676
+ batch_size = prompt_embeds.shape[0]
677
+
678
+ # 3. Encode input prompt
679
+ transformer_dtype = self.transformer.dtype
680
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
681
+ prompt=prompt,
682
+ prompt_2=prompt_2,
683
+ prompt_template=prompt_template,
684
+ num_videos_per_prompt=num_videos_per_prompt,
685
+ prompt_embeds=prompt_embeds,
686
+ pooled_prompt_embeds=pooled_prompt_embeds,
687
+ prompt_attention_mask=prompt_attention_mask,
688
+ device=device,
689
+ max_sequence_length=max_sequence_length,
690
+ )
691
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
692
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
693
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
694
+
695
+ if do_true_cfg:
696
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
697
+ prompt=negative_prompt,
698
+ prompt_2=negative_prompt_2,
699
+ prompt_template=prompt_template,
700
+ num_videos_per_prompt=num_videos_per_prompt,
701
+ prompt_embeds=negative_prompt_embeds,
702
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
703
+ prompt_attention_mask=negative_prompt_attention_mask,
704
+ device=device,
705
+ max_sequence_length=max_sequence_length,
706
+ )
707
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
708
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
709
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
710
+
711
+ # 4. Prepare timesteps
712
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
713
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
714
+
715
+ # 5. Prepare latent variables
716
+ num_channels_latents = self.transformer.config.in_channels
717
+ latents = self.prepare_latents(
718
+ batch_size * num_videos_per_prompt,
719
+ num_channels_latents,
720
+ height,
721
+ width,
722
+ num_frames,
723
+ torch.float32,
724
+ device,
725
+ generator,
726
+ latents,
727
+ )
728
+
729
+ # 6. Prepare guidance condition
730
+ guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
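+ # Embedded (distilled) guidance: the scalar `guidance_scale` is multiplied by 1000 and passed to
+ # the transformer as a per-sample conditioning tensor (e.g. 6.0 -> 6000.0 for every latent in the batch).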
731
+
732
+ # 7. Denoising loop
733
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
734
+ self._num_timesteps = len(timesteps)
735
+
736
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
737
+ for i, t in enumerate(timesteps):
738
+ if self.interrupt:
739
+ continue
740
+
741
+ self._current_timestep = t
742
+ latent_model_input = latents.to(transformer_dtype)
743
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
744
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
745
+
746
+ noise_pred = self.transformer(
747
+ hidden_states=latent_model_input,
748
+ timestep=timestep,
749
+ encoder_hidden_states=prompt_embeds,
750
+ encoder_attention_mask=prompt_attention_mask,
751
+ pooled_projections=pooled_prompt_embeds,
752
+ guidance=guidance,
753
+ attention_kwargs=attention_kwargs,
754
+ return_dict=False,
755
+ )[0]
756
+
757
+ if do_true_cfg:
758
+ neg_noise_pred = self.transformer(
759
+ hidden_states=latent_model_input,
760
+ timestep=timestep,
761
+ encoder_hidden_states=negative_prompt_embeds,
762
+ encoder_attention_mask=negative_prompt_attention_mask,
763
+ pooled_projections=negative_pooled_prompt_embeds,
764
+ guidance=guidance,
765
+ attention_kwargs=attention_kwargs,
766
+ return_dict=False,
767
+ )[0]
768
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
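+ # Standard classifier-free guidance combination: pred = uncond + s * (cond - uncond), s = true_cfg_scale.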
769
+
770
+ # compute the previous noisy sample x_t -> x_t-1
771
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
772
+
773
+ if callback_on_step_end is not None:
774
+ callback_kwargs = {}
775
+ for k in callback_on_step_end_tensor_inputs:
776
+ callback_kwargs[k] = locals()[k]
777
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
778
+
779
+ latents = callback_outputs.pop("latents", latents)
780
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
781
+
782
+ # call the callback, if provided
783
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
784
+ progress_bar.update()
785
+
786
+ if XLA_AVAILABLE:
787
+ xm.mark_step()
788
+
789
+ self._current_timestep = None
790
+
791
+ if output_type == "numpy":
792
+ video = self.decode_latents(latents)
793
+ elif not output_type == "latent":
794
+ video = self.decode_latents(latents)
795
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
796
+ else:
797
+ video = latents
798
+
799
+ # Offload all models
800
+ self.maybe_free_model_hooks()
801
+
802
+ if not return_dict:
803
+ video = torch.from_numpy(video)
804
+
805
+ return HunyuanVideoPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py ADDED
@@ -0,0 +1,972 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
2
+ # Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import PIL
22
+ import torch
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.loaders import HunyuanVideoLoraLoaderMixin
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available,
28
+ logging, replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+
32
+ from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor,
33
+ CLIPTextModel, CLIPTokenizer,
34
+ HunyuanVideoTransformer3DModel, LlamaModel,
35
+ LlamaTokenizerFast, LlavaForConditionalGeneration)
36
+
37
+ if is_torch_xla_available():
38
+ import torch_xla.core.xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else:
42
+ XLA_AVAILABLE = False
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
+ EXAMPLE_DOC_STRING = """
48
+ Examples:
49
+ ```python
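+ >>> # No example was provided in this commit. A minimal sketch (the argument names below are
+ >>> # assumptions; see `HunyuanVideoI2VPipeline.__call__` for the authoritative signature):
+ >>> # output = pipe(image=first_frame, prompt="A cat walks on the grass, realistic",
+ >>> #               height=720, width=1280, num_frames=129)
+ >>> # video = output.videos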
50
+ ```
51
+ """
52
+
53
+
54
+ DEFAULT_PROMPT_TEMPLATE = {
55
+ "template": (
56
+ "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
57
+ "1. The main content and theme of the video."
58
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
59
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
60
+ "4. background environment, light, style and atmosphere."
61
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
62
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
63
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
64
+ ),
65
+ "crop_start": 103,
66
+ "image_emb_start": 5,
67
+ "image_emb_end": 581,
68
+ "image_emb_len": 576,
69
+ "double_return_token_id": 271,
70
+ }
71
+
72
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
73
+ def retrieve_timesteps(
74
+ scheduler,
75
+ num_inference_steps: Optional[int] = None,
76
+ device: Optional[Union[str, torch.device]] = None,
77
+ timesteps: Optional[List[int]] = None,
78
+ sigmas: Optional[List[float]] = None,
79
+ **kwargs,
80
+ ):
81
+ r"""
82
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
83
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
84
+
85
+ Args:
86
+ scheduler (`SchedulerMixin`):
87
+ The scheduler to get timesteps from.
88
+ num_inference_steps (`int`):
89
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
90
+ must be `None`.
91
+ device (`str` or `torch.device`, *optional*):
92
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
93
+ timesteps (`List[int]`, *optional*):
94
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
95
+ `num_inference_steps` and `sigmas` must be `None`.
96
+ sigmas (`List[float]`, *optional*):
97
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
98
+ `num_inference_steps` and `timesteps` must be `None`.
99
+
100
+ Returns:
101
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
102
+ second element is the number of inference steps.
103
+ """
104
+ if timesteps is not None and sigmas is not None:
105
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
106
+ if timesteps is not None:
107
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
108
+ if not accepts_timesteps:
109
+ raise ValueError(
110
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
111
+ f" timestep schedules. Please check whether you are using the correct scheduler."
112
+ )
113
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
114
+ timesteps = scheduler.timesteps
115
+ num_inference_steps = len(timesteps)
116
+ elif sigmas is not None:
117
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
118
+ if not accept_sigmas:
119
+ raise ValueError(
120
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
121
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
122
+ )
123
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
124
+ timesteps = scheduler.timesteps
125
+ num_inference_steps = len(timesteps)
126
+ else:
127
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ return timesteps, num_inference_steps
130
+
131
+
132
+ def _expand_input_ids_with_image_tokens(
133
+ text_input_ids,
134
+ prompt_attention_mask,
135
+ max_sequence_length,
136
+ image_token_index,
137
+ image_emb_len,
138
+ image_emb_start,
139
+ image_emb_end,
140
+ pad_token_id,
141
+ ):
142
+ special_image_token_mask = text_input_ids == image_token_index
143
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
144
+ batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
145
+
146
+ max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
147
+ new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
148
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
149
+
150
+ expanded_input_ids = torch.full(
151
+ (text_input_ids.shape[0], max_expanded_length),
152
+ pad_token_id,
153
+ dtype=text_input_ids.dtype,
154
+ device=text_input_ids.device,
155
+ )
156
+ expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
157
+ expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
158
+
159
+ expanded_attention_mask = torch.zeros(
160
+ (text_input_ids.shape[0], max_expanded_length),
161
+ dtype=prompt_attention_mask.dtype,
162
+ device=prompt_attention_mask.device,
163
+ )
164
+ attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
165
+ expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
166
+ expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
167
+ position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
168
+
169
+ return {
170
+ "input_ids": expanded_input_ids,
171
+ "attention_mask": expanded_attention_mask,
172
+ "position_ids": position_ids,
173
+ }
174
+
175
+
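+ # Descriptive note (added by the editor): the helper above expands the single `<image>`
+ # placeholder token in the templated prompt into `image_emb_len` (576) positions, so the Llava
+ # encoder can place the image patch embeddings at [image_emb_start, image_emb_end), and it
+ # rebuilds the attention mask and position ids to match the longer sequence.
+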
176
+ @dataclass
177
+ class HunyuanVideoPipelineOutput(BaseOutput):
178
+ r"""
179
+ Output class for video pipelines.
180
+
181
+ Args:
182
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
183
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
184
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
185
+ `(batch_size, num_frames, channels, height, width)`.
186
+ """
187
+
188
+ videos: torch.Tensor
189
+
190
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
191
+ def retrieve_latents(
192
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
193
+ ):
194
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
195
+ return encoder_output.latent_dist.sample(generator)
196
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
197
+ return encoder_output.latent_dist.mode()
198
+ elif hasattr(encoder_output, "latents"):
199
+ return encoder_output.latents
200
+ else:
201
+ raise AttributeError("Could not access latents of provided encoder_output")
202
+
203
+ class HunyuanVideoI2VPipeline(DiffusionPipeline):
204
+ r"""
205
+ Pipeline for image-to-video generation using HunyuanVideo.
206
+
207
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
208
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
209
+
210
+ Args:
211
+ text_encoder ([`LlamaModel`]):
212
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
213
+ tokenizer (`LlamaTokenizer`):
214
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
215
+ transformer ([`HunyuanVideoTransformer3DModel`]):
216
+ Conditional Transformer to denoise the encoded image latents.
217
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
218
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
219
+ vae ([`AutoencoderKLHunyuanVideo`]):
220
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
221
+ text_encoder_2 ([`CLIPTextModel`]):
222
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
223
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
224
+ tokenizer_2 (`CLIPTokenizer`):
225
+ Tokenizer of class
226
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
227
+ """
228
+
229
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
230
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
231
+
232
+ def __init__(
233
+ self,
234
+ text_encoder: LlavaForConditionalGeneration,
235
+ tokenizer: LlamaTokenizerFast,
236
+ transformer: HunyuanVideoTransformer3DModel,
237
+ vae: AutoencoderKLHunyuanVideo,
238
+ scheduler: FlowMatchEulerDiscreteScheduler,
239
+ text_encoder_2: CLIPTextModel,
240
+ tokenizer_2: CLIPTokenizer,
241
+ image_processor: CLIPImageProcessor,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.register_modules(
246
+ vae=vae,
247
+ text_encoder=text_encoder,
248
+ tokenizer=tokenizer,
249
+ transformer=transformer,
250
+ scheduler=scheduler,
251
+ text_encoder_2=text_encoder_2,
252
+ tokenizer_2=tokenizer_2,
253
+ image_processor=image_processor,
254
+ )
255
+
256
+ self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
257
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
258
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
259
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
260
+
261
+ def _get_llama_prompt_embeds(
262
+ self,
263
+ image: torch.Tensor,
264
+ prompt: Union[str, List[str]],
265
+ prompt_template: Dict[str, Any],
266
+ num_videos_per_prompt: int = 1,
267
+ device: Optional[torch.device] = None,
268
+ dtype: Optional[torch.dtype] = None,
269
+ max_sequence_length: int = 256,
270
+ num_hidden_layers_to_skip: int = 2,
271
+ image_embed_interleave: int = 2,
272
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
273
+ device = device or self._execution_device
274
+ dtype = dtype or self.text_encoder.dtype
275
+
276
+ prompt = [prompt] if isinstance(prompt, str) else prompt
277
+ prompt = [prompt_template["template"].format(p) for p in prompt]
278
+
279
+ crop_start = prompt_template.get("crop_start", None)
280
+
281
+ image_emb_len = prompt_template.get("image_emb_len", 576)
282
+ image_emb_start = prompt_template.get("image_emb_start", 5)
283
+ image_emb_end = prompt_template.get("image_emb_end", 581)
284
+ double_return_token_id = prompt_template.get("double_return_token_id", 271)
285
+
286
+ if crop_start is None:
287
+ prompt_template_input = self.tokenizer(
288
+ prompt_template["template"],
289
+ padding="max_length",
290
+ return_tensors="pt",
291
+ return_length=False,
292
+ return_overflowing_tokens=False,
293
+ return_attention_mask=False,
294
+ )
295
+ crop_start = prompt_template_input["input_ids"].shape[-1]
296
+ # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {}
297
+ crop_start -= 5
298
+
299
+ max_sequence_length += crop_start
300
+ text_inputs = self.tokenizer(
301
+ prompt,
302
+ max_length=max_sequence_length,
303
+ padding="max_length",
304
+ truncation=True,
305
+ return_tensors="pt",
306
+ return_length=False,
307
+ return_overflowing_tokens=False,
308
+ return_attention_mask=True,
309
+ )
310
+ text_input_ids = text_inputs.input_ids.to(device=device)
311
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
312
+
313
+ image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
314
+
315
+ image_token_index = self.text_encoder.config.image_token_index
316
+ pad_token_id = self.text_encoder.config.pad_token_id
317
+ expanded_inputs = _expand_input_ids_with_image_tokens(
318
+ text_input_ids,
319
+ prompt_attention_mask,
320
+ max_sequence_length,
321
+ image_token_index,
322
+ image_emb_len,
323
+ image_emb_start,
324
+ image_emb_end,
325
+ pad_token_id,
326
+ )
327
+ prompt_embeds = self.text_encoder(
328
+ **expanded_inputs,
329
+ pixel_values=image_embeds,
330
+ output_hidden_states=True,
331
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
332
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
333
+
334
+ if crop_start is not None and crop_start > 0:
335
+ text_crop_start = crop_start - 1 + image_emb_len
336
+ batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
337
+
338
+ if last_double_return_token_indices.shape[0] == 3:
339
+ # in case the prompt is too long
340
+ last_double_return_token_indices = torch.cat(
341
+ (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
342
+ )
343
+ batch_indices = torch.cat((batch_indices, torch.tensor([0])))
344
+
345
+ last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
346
+ :, -1
347
+ ]
348
+ batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
349
+ assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
350
+ assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
351
+ attention_mask_assistant_crop_start = last_double_return_token_indices - 4
352
+ attention_mask_assistant_crop_end = last_double_return_token_indices
353
+
354
+ prompt_embed_list = []
355
+ prompt_attention_mask_list = []
356
+ image_embed_list = []
357
+ image_attention_mask_list = []
358
+
359
+ for i in range(text_input_ids.shape[0]):
360
+ prompt_embed_list.append(
361
+ torch.cat(
362
+ [
363
+ prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()],
364
+ prompt_embeds[i, assistant_crop_end[i].item() :],
365
+ ]
366
+ )
367
+ )
368
+ prompt_attention_mask_list.append(
369
+ torch.cat(
370
+ [
371
+ prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()],
372
+ prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :],
373
+ ]
374
+ )
375
+ )
376
+ image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end])
377
+ image_attention_mask_list.append(
378
+ torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype)
379
+ )
380
+
381
+ prompt_embed_list = torch.stack(prompt_embed_list)
382
+ prompt_attention_mask_list = torch.stack(prompt_attention_mask_list)
383
+ image_embed_list = torch.stack(image_embed_list)
384
+ image_attention_mask_list = torch.stack(image_attention_mask_list)
385
+
386
+ if 0 < image_embed_interleave < 6:
387
+ image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
388
+ image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]
389
+
390
+ assert (
391
+ prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0]
392
+ and image_embed_list.shape[0] == image_attention_mask_list.shape[0]
393
+ )
394
+
395
+ prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
396
+ prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
397
+
398
+ return prompt_embeds, prompt_attention_mask
399
+
400
+ def _get_clip_prompt_embeds(
401
+ self,
402
+ prompt: Union[str, List[str]],
403
+ num_videos_per_prompt: int = 1,
404
+ device: Optional[torch.device] = None,
405
+ dtype: Optional[torch.dtype] = None,
406
+ max_sequence_length: int = 77,
407
+ ) -> torch.Tensor:
408
+ device = device or self._execution_device
409
+ dtype = dtype or self.text_encoder_2.dtype
410
+
411
+ prompt = [prompt] if isinstance(prompt, str) else prompt
412
+
413
+ text_inputs = self.tokenizer_2(
414
+ prompt,
415
+ padding="max_length",
416
+ max_length=max_sequence_length,
417
+ truncation=True,
418
+ return_tensors="pt",
419
+ )
420
+
421
+ text_input_ids = text_inputs.input_ids
422
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
423
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
424
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
425
+ logger.warning(
426
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
427
+ f" {max_sequence_length} tokens: {removed_text}"
428
+ )
429
+
430
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
431
+ return prompt_embeds
432
+
433
+ def encode_prompt(
434
+ self,
435
+ image: torch.Tensor,
436
+ prompt: Union[str, List[str]],
437
+ prompt_2: Union[str, List[str]] = None,
438
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
439
+ num_videos_per_prompt: int = 1,
440
+ prompt_embeds: Optional[torch.Tensor] = None,
441
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
442
+ prompt_attention_mask: Optional[torch.Tensor] = None,
443
+ device: Optional[torch.device] = None,
444
+ dtype: Optional[torch.dtype] = None,
445
+ max_sequence_length: int = 256,
446
+ image_embed_interleave: int = 2,
447
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
448
+ if prompt_embeds is None:
449
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
450
+ image,
451
+ prompt,
452
+ prompt_template,
453
+ num_videos_per_prompt,
454
+ device=device,
455
+ dtype=dtype,
456
+ max_sequence_length=max_sequence_length,
457
+ image_embed_interleave=image_embed_interleave,
458
+ )
459
+
460
+ if pooled_prompt_embeds is None:
461
+ if prompt_2 is None:
462
+ prompt_2 = prompt
463
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
464
+ prompt,
465
+ num_videos_per_prompt,
466
+ device=device,
467
+ dtype=dtype,
468
+ max_sequence_length=77,
469
+ )
470
+
471
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
472
+
473
+ def check_inputs(
474
+ self,
475
+ prompt,
476
+ prompt_2,
477
+ height,
478
+ width,
479
+ prompt_embeds=None,
480
+ callback_on_step_end_tensor_inputs=None,
481
+ prompt_template=None,
482
+ true_cfg_scale=1.0,
483
+ guidance_scale=1.0,
484
+ ):
485
+ if height % 16 != 0 or width % 16 != 0:
486
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
487
+
488
+ if callback_on_step_end_tensor_inputs is not None and not all(
489
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
490
+ ):
491
+ raise ValueError(
492
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
493
+ )
494
+
495
+ if prompt is not None and prompt_embeds is not None:
496
+ raise ValueError(
497
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
498
+ " only forward one of the two."
499
+ )
500
+ elif prompt_2 is not None and prompt_embeds is not None:
501
+ raise ValueError(
502
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
503
+ " only forward one of the two."
504
+ )
505
+ elif prompt is None and prompt_embeds is None:
506
+ raise ValueError(
507
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
508
+ )
509
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
510
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
511
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
512
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
513
+
514
+ if prompt_template is not None:
515
+ if not isinstance(prompt_template, dict):
516
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
517
+ if "template" not in prompt_template:
518
+ raise ValueError(
519
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
520
+ )
521
+
522
+ if true_cfg_scale > 1.0 and guidance_scale > 1.0:
523
+ logger.warning(
524
+ "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
525
+ "classifier-free guidance and embedded-guidance to be applied. This is not recommended "
526
+ "as it may lead to higher memory usage, slower inference and potentially worse results."
527
+ )
528
+
529
+ def prepare_latents(
530
+ self,
531
+ image: torch.Tensor,
532
+ batch_size: int,
533
+ num_channels_latents: int = 32,
534
+ height: int = 720,
535
+ width: int = 1280,
536
+ num_frames: int = 129,
537
+ dtype: Optional[torch.dtype] = None,
538
+ device: Optional[torch.device] = None,
539
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
540
+ latents: Optional[torch.Tensor] = None,
541
+ ) -> torch.Tensor:
542
+ if isinstance(generator, list) and len(generator) != batch_size:
543
+ raise ValueError(
544
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
545
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
546
+ )
547
+
548
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
549
+ latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
550
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
551
+
552
+ image = image.unsqueeze(2) # [B, C, 1, H, W]
553
+ if isinstance(generator, list):
554
+ image_latents = [
555
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
556
+ for i in range(batch_size)
557
+ ]
558
+ else:
559
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
560
+
561
+ image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
562
+ image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
563
+
564
+ if latents is None:
565
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
566
+ else:
567
+ latents = latents.to(device=device, dtype=dtype)
568
+
569
+ t = torch.tensor([0.999]).to(device=device)
570
+ latents = latents * t + image_latents * (1 - t)
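+ # Blend a small fraction (1 - 0.999) of the encoded image latents (repeated over all latent
+ # frames above) into the initial noise, so denoising starts from almost-pure noise that is
+ # weakly anchored to the conditioning image.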
571
+
572
+ image_latents = image_latents[:, :, :1]
573
+ return latents, image_latents
574
+
575
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
576
+ latents = 1 / self.vae.config.scaling_factor * latents
577
+
578
+ frames = self.vae.decode(latents).sample
579
+ frames = (frames / 2 + 0.5).clamp(0, 1)
580
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
581
+ frames = frames.cpu().float().numpy()
582
+ return frames
583
+
584
+ def enable_vae_slicing(self):
585
+ r"""
586
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
587
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
588
+ """
589
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
590
+ deprecate(
591
+ "enable_vae_slicing",
592
+ "0.40.0",
593
+ depr_message,
594
+ )
595
+ self.vae.enable_slicing()
596
+
597
+ def disable_vae_slicing(self):
598
+ r"""
599
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
600
+ computing decoding in one step.
601
+ """
602
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
603
+ deprecate(
604
+ "disable_vae_slicing",
605
+ "0.40.0",
606
+ depr_message,
607
+ )
608
+ self.vae.disable_slicing()
609
+
610
+ def enable_vae_tiling(self):
611
+ r"""
612
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
613
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
614
+ processing larger images.
615
+ """
616
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
617
+ deprecate(
618
+ "enable_vae_tiling",
619
+ "0.40.0",
620
+ depr_message,
621
+ )
622
+ self.vae.enable_tiling()
623
+
624
+ def disable_vae_tiling(self):
625
+ r"""
626
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
627
+ computing decoding in one step.
628
+ """
629
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
630
+ deprecate(
631
+ "disable_vae_tiling",
632
+ "0.40.0",
633
+ depr_message,
634
+ )
635
+ self.vae.disable_tiling()
636
+
637
+ @property
638
+ def guidance_scale(self):
639
+ return self._guidance_scale
640
+
641
+ @property
642
+ def num_timesteps(self):
643
+ return self._num_timesteps
644
+
645
+ @property
646
+ def attention_kwargs(self):
647
+ return self._attention_kwargs
648
+
649
+ @property
650
+ def current_timestep(self):
651
+ return self._current_timestep
652
+
653
+ @property
654
+ def interrupt(self):
655
+ return self._interrupt
656
+
657
+ @torch.no_grad()
658
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
659
+ def __call__(
660
+ self,
661
+ prompt: Union[str, List[str]] = None,
662
+ prompt_2: Union[str, List[str]] = None,
663
+ negative_prompt: Union[str, List[str]] = None,
664
+ negative_prompt_2: Union[str, List[str]] = None,
665
+ height: int = 720,
666
+ width: int = 1280,
667
+ num_frames: int = 129,
668
+ num_inference_steps: int = 50,
669
+ sigmas: List[float] = None,
670
+ true_cfg_scale: float = 1.0,
671
+ guidance_scale: float = 6.0,
672
+ num_videos_per_prompt: Optional[int] = 1,
673
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
674
+ latents: Optional[torch.Tensor] = None,
675
+ prompt_embeds: Optional[torch.Tensor] = None,
676
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
677
+ prompt_attention_mask: Optional[torch.Tensor] = None,
678
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
679
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
680
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
681
+ output_type: str = "numpy",
682
+ return_dict: bool = False,
683
+ attention_kwargs: Optional[Dict[str, Any]] = None,
684
+ callback_on_step_end: Optional[
685
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
686
+ ] = None,
687
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
688
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
689
+ image: PIL.Image.Image = None,
690
+ max_sequence_length: int = 256,
691
+ image_embed_interleave: Optional[int] = None,
692
+ ):
693
+ r"""
694
+ The call function to the pipeline for generation.
695
+
696
+ Args:
697
+ prompt (`str` or `List[str]`, *optional*):
698
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
699
+ instead.
700
+ prompt_2 (`str` or `List[str]`, *optional*):
701
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
702
+ will be used instead.
703
+ negative_prompt (`str` or `List[str]`, *optional*):
704
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
705
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
706
+ not greater than `1`).
707
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
708
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
709
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
710
+ height (`int`, defaults to `720`):
711
+ The height in pixels of the generated image.
712
+ width (`int`, defaults to `1280`):
713
+ The width in pixels of the generated image.
714
+ num_frames (`int`, defaults to `129`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, defaults to `50`):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ sigmas (`List[float]`, *optional*):
720
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
721
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
722
+ will be used.
723
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
724
+ When greater than 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
725
+ guidance_scale (`float`, defaults to `6.0`):
726
+ Guidance scale as defined in [Classifier-Free Diffusion
727
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
728
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
729
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
730
+ the text `prompt`, usually at the expense of lower image quality. Note that the only available
731
+ HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
732
+ conditional latent is not applied.
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of images to generate per prompt.
735
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
736
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
737
+ generation deterministic.
738
+ latents (`torch.Tensor`, *optional*):
739
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
740
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
741
+ tensor is generated by sampling using the supplied random `generator`.
742
+ prompt_embeds (`torch.Tensor`, *optional*):
743
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
744
+ provided, text embeddings are generated from the `prompt` input argument.
745
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
746
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
747
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
748
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
750
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
751
+ argument.
752
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
753
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
754
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
755
+ input argument.
756
+ output_type (`str`, *optional*, defaults to `"numpy"`):
757
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
758
+ return_dict (`bool`, *optional*, defaults to `False`):
759
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
760
+ attention_kwargs (`dict`, *optional*):
761
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
762
+ `self.processor` in
763
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
764
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
782
+ where the first element is a list with the generated videos.
784
+ """
785
+
786
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
787
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
788
+
789
+ # 1. Check inputs. Raise error if not correct
790
+ self.check_inputs(
791
+ prompt,
792
+ prompt_2,
793
+ height,
794
+ width,
795
+ prompt_embeds,
796
+ callback_on_step_end_tensor_inputs,
797
+ prompt_template,
798
+ true_cfg_scale,
799
+ guidance_scale,
800
+ )
801
+
802
+ has_neg_prompt = negative_prompt is not None or (
803
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
804
+ )
805
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
806
+ image_embed_interleave = (
807
+ image_embed_interleave
808
+ if image_embed_interleave is not None
809
+ else 4
810
+ )
811
+
812
+ self._guidance_scale = guidance_scale
813
+ self._attention_kwargs = attention_kwargs
814
+ self._current_timestep = None
815
+ self._interrupt = False
816
+
817
+ device = self._execution_device
818
+
819
+ # 2. Define call parameters
820
+ if prompt is not None and isinstance(prompt, str):
821
+ batch_size = 1
822
+ elif prompt is not None and isinstance(prompt, list):
823
+ batch_size = len(prompt)
824
+ else:
825
+ batch_size = prompt_embeds.shape[0]
826
+
827
+ # 3. Prepare latent variables
828
+ vae_dtype = self.vae.dtype
829
+ image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
830
+
831
+ num_channels_latents = self.transformer.config.in_channels
832
+
833
+ latents, image_latents = self.prepare_latents(
834
+ image_tensor,
835
+ batch_size * num_videos_per_prompt,
836
+ num_channels_latents,
837
+ height,
838
+ width,
839
+ num_frames,
840
+ torch.float32,
841
+ device,
842
+ generator,
843
+ latents,
844
+ )
845
+
846
+ # 4. Encode input prompt
847
+ transformer_dtype = self.transformer.dtype
848
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
849
+ image=image,
850
+ prompt=prompt,
851
+ prompt_2=prompt_2,
852
+ prompt_template=prompt_template,
853
+ num_videos_per_prompt=num_videos_per_prompt,
854
+ prompt_embeds=prompt_embeds,
855
+ pooled_prompt_embeds=pooled_prompt_embeds,
856
+ prompt_attention_mask=prompt_attention_mask,
857
+ device=device,
858
+ max_sequence_length=max_sequence_length,
859
+ image_embed_interleave=image_embed_interleave,
860
+ )
861
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
862
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
863
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
864
+
865
+ if do_true_cfg:
866
+ black_image = PIL.Image.new("RGB", (width, height), 0)
867
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
868
+ image=black_image,
869
+ prompt=negative_prompt,
870
+ prompt_2=negative_prompt_2,
871
+ prompt_template=prompt_template,
872
+ num_videos_per_prompt=num_videos_per_prompt,
873
+ prompt_embeds=negative_prompt_embeds,
874
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
875
+ prompt_attention_mask=negative_prompt_attention_mask,
876
+ device=device,
877
+ max_sequence_length=max_sequence_length,
878
+ )
879
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
880
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
881
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
882
+
883
+ # 5. Prepare timesteps
884
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
885
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
886
+
887
+ # 6. Prepare guidance condition
888
+ guidance = None
889
+ if self.transformer.config.guidance_embeds:
890
+ guidance = (
891
+ torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
892
+ )
893
+
894
+ # 7. Denoising loop
895
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
896
+ self._num_timesteps = len(timesteps)
897
+
898
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
899
+ for i, t in enumerate(timesteps):
900
+ if self.interrupt:
901
+ continue
902
+
903
+ self._current_timestep = t
904
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
905
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
906
+
907
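+ # Pin the first latent frame to the encoded conditioning image; only the remaining frames are denoised.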
+ latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
908
+
909
+ noise_pred = self.transformer(
910
+ hidden_states=latent_model_input,
911
+ timestep=timestep,
912
+ encoder_hidden_states=prompt_embeds,
913
+ encoder_attention_mask=prompt_attention_mask,
914
+ pooled_projections=pooled_prompt_embeds,
915
+ guidance=guidance,
916
+ attention_kwargs=attention_kwargs,
917
+ return_dict=False,
918
+ )[0]
919
+
920
+ if do_true_cfg:
921
+ neg_noise_pred = self.transformer(
922
+ hidden_states=latent_model_input,
923
+ timestep=timestep,
924
+ encoder_hidden_states=negative_prompt_embeds,
925
+ encoder_attention_mask=negative_prompt_attention_mask,
926
+ pooled_projections=negative_pooled_prompt_embeds,
927
+ guidance=guidance,
928
+ attention_kwargs=attention_kwargs,
929
+ return_dict=False,
930
+ )[0]
931
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
932
+
933
+ # compute the previous noisy sample x_t -> x_t-1
934
+ latents = self.scheduler.step(
935
+ noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
936
+ )[0]
937
+ latents = torch.cat([image_latents, latents], dim=2)
938
+ latents = latents.to(self.vae.dtype)
939
+
940
+ if callback_on_step_end is not None:
941
+ callback_kwargs = {}
942
+ for k in callback_on_step_end_tensor_inputs:
943
+ callback_kwargs[k] = locals()[k]
944
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
945
+
946
+ latents = callback_outputs.pop("latents", latents)
947
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
948
+
949
+ # call the callback, if provided
950
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
951
+ progress_bar.update()
952
+
953
+ if XLA_AVAILABLE:
954
+ xm.mark_step()
955
+
956
+ self._current_timestep = None
957
+
958
+ if output_type == "numpy":
959
+ video = self.decode_latents(latents)
960
+ elif output_type != "latent":
961
+ video = self.decode_latents(latents)
962
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
963
+ else:
964
+ video = latents
965
+
966
+ # Offload all models
967
+ self.maybe_free_model_hooks()
968
+
969
+ if not return_dict:
970
+ video = torch.from_numpy(video)
971
+
972
+ return HunyuanVideoPipelineOutput(videos=video)
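The image conditioning above reduces to two tensor operations: `prepare_latents` blends Gaussian noise with the repeated image latents at a fixed 0.999/0.001 ratio, and every denoising step overwrites the first latent frame with the encoded image before the transformer call. A minimal, self-contained sketch of that mechanic with dummy tensors (the 0.999 ratio and the concatenation pattern come from the code above; the shapes and the fake scheduler step are purely illustrative):

```py
import torch

# Stand-ins for real VAE latents: (B, C, T_latent, H_latent, W_latent).
b, c, t_lat, h_lat, w_lat = 1, 16, 33, 90, 160
image_latents = torch.randn(b, c, 1, h_lat, w_lat)   # encoded conditioning frame
noise = torch.randn(b, c, t_lat, h_lat, w_lat)

# prepare_latents: almost pure noise, with a small contribution from the image.
blend = 0.999
latents = noise * blend + image_latents.repeat(1, 1, t_lat, 1, 1) * (1 - blend)

# Denoising loop: frame 0 is pinned to the image latents; only frames 1..T-1 are updated.
for _ in range(3):                                      # a few fake steps
    model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2)
    noise_pred = torch.zeros_like(model_input)          # stand-in for the transformer
    tail = latents[:, :, 1:] - 0.1 * noise_pred[:, :, 1:]  # stand-in for scheduler.step
    latents = torch.cat([image_latents, tail], dim=2)

print(latents.shape)  # torch.Size([1, 16, 33, 90, 160])
```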
videox_fun/pipeline/pipeline_qwenimage.py ADDED
@@ -0,0 +1,767 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
2
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
26
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
27
+ replace_example_docstring)
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+
30
+ from ..models import (AutoencoderKLQwenImage,
31
+ Qwen2_5_VLForConditionalGeneration,
32
+ Qwen2Tokenizer, QwenImageTransformer2DModel)
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from diffusers import QwenImagePipeline
49
+
50
+ >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
51
+ >>> pipe.to("cuda")
52
+ >>> prompt = "A cat holding a sign that says hello world"
53
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
54
+ >>> # Refer to the pipeline documentation for more details.
55
+ >>> image = pipe(prompt, num_inference_steps=50).images[0]
56
+ >>> image.save("qwenimage.png")
57
+ ```
58
+ """
59
+
60
+
61
+ def calculate_shift(
62
+ image_seq_len,
63
+ base_seq_len: int = 256,
64
+ max_seq_len: int = 4096,
65
+ base_shift: float = 0.5,
66
+ max_shift: float = 1.15,
67
+ ):
68
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
69
+ b = base_shift - m * base_seq_len
70
+ mu = image_seq_len * m + b
71
+ return mu
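+ # Illustrative example: with the default 8x VAE downsampling and 2x2 packing, a 1024x1024 image
+ # yields image_seq_len = 64 * 64 = 4096, so mu equals max_shift (1.15); shorter sequences
+ # interpolate linearly down toward base_shift.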
72
+
73
+
74
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
75
+ def retrieve_timesteps(
76
+ scheduler,
77
+ num_inference_steps: Optional[int] = None,
78
+ device: Optional[Union[str, torch.device]] = None,
79
+ timesteps: Optional[List[int]] = None,
80
+ sigmas: Optional[List[float]] = None,
81
+ **kwargs,
82
+ ):
83
+ r"""
84
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
85
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
86
+
87
+ Args:
88
+ scheduler (`SchedulerMixin`):
89
+ The scheduler to get timesteps from.
90
+ num_inference_steps (`int`):
91
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
92
+ must be `None`.
93
+ device (`str` or `torch.device`, *optional*):
94
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
95
+ timesteps (`List[int]`, *optional*):
96
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
97
+ `num_inference_steps` and `sigmas` must be `None`.
98
+ sigmas (`List[float]`, *optional*):
99
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
100
+ `num_inference_steps` and `timesteps` must be `None`.
101
+
102
+ Returns:
103
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
104
+ second element is the number of inference steps.
105
+ """
106
+ if timesteps is not None and sigmas is not None:
107
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
108
+ if timesteps is not None:
109
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
110
+ if not accepts_timesteps:
111
+ raise ValueError(
112
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
113
+ f" timestep schedules. Please check whether you are using the correct scheduler."
114
+ )
115
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
116
+ timesteps = scheduler.timesteps
117
+ num_inference_steps = len(timesteps)
118
+ elif sigmas is not None:
119
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
120
+ if not accept_sigmas:
121
+ raise ValueError(
122
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
123
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
124
+ )
125
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
126
+ timesteps = scheduler.timesteps
127
+ num_inference_steps = len(timesteps)
128
+ else:
129
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
130
+ timesteps = scheduler.timesteps
131
+ return timesteps, num_inference_steps
132
+
133
+
134
+ @dataclass
135
+ class QwenImagePipelineOutput(BaseOutput):
136
+ """
137
+ Output class for QwenImage pipelines.
138
+
139
+ Args:
140
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
141
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
142
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
143
+ """
144
+
145
+ images: Union[List[PIL.Image.Image], np.ndarray]
146
+
147
+
148
+ class QwenImagePipeline(DiffusionPipeline):
149
+ r"""
150
+ The QwenImage pipeline for text-to-image generation.
151
+
152
+ Args:
153
+ transformer ([`QwenImageTransformer2DModel`]):
154
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
155
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
156
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
157
+ vae ([`AutoencoderKL`]):
158
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
159
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
160
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the vision-language
161
+ model used as the text encoder.
162
+ tokenizer (`QwenTokenizer`):
163
+ Tokenizer of class
164
+ [`Qwen2Tokenizer`].
165
+ """
166
+
167
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
168
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
169
+
170
+ def __init__(
171
+ self,
172
+ scheduler: FlowMatchEulerDiscreteScheduler,
173
+ vae: AutoencoderKLQwenImage,
174
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
175
+ tokenizer: Qwen2Tokenizer,
176
+ transformer: QwenImageTransformer2DModel,
177
+ ):
178
+ super().__init__()
179
+
180
+ self.register_modules(
181
+ vae=vae,
182
+ text_encoder=text_encoder,
183
+ tokenizer=tokenizer,
184
+ transformer=transformer,
185
+ scheduler=scheduler,
186
+ )
187
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
188
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
189
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
190
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
191
+ self.tokenizer_max_length = 1024
192
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
193
+ self.prompt_template_encode_start_idx = 34
194
+ self.default_sample_size = 128
195
+
196
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
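+ # Keep only the token positions where the attention mask is set, then split the result back into
+ # per-sample variable-length sequences.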
197
+ bool_mask = mask.bool()
198
+ valid_lengths = bool_mask.sum(dim=1)
199
+ selected = hidden_states[bool_mask]
200
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
201
+
202
+ return split_result
203
+
204
+ def _get_qwen_prompt_embeds(
205
+ self,
206
+ prompt: Union[str, List[str]] = None,
207
+ device: Optional[torch.device] = None,
208
+ dtype: Optional[torch.dtype] = None,
209
+ ):
210
+ device = device or self._execution_device
211
+ dtype = dtype or self.text_encoder.dtype
212
+
213
+ prompt = [prompt] if isinstance(prompt, str) else prompt
214
+
215
+ template = self.prompt_template_encode
216
+ drop_idx = self.prompt_template_encode_start_idx
217
+ txt = [template.format(e) for e in prompt]
218
+ txt_tokens = self.tokenizer(
219
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
220
+ ).to(device)
221
+ encoder_hidden_states = self.text_encoder(
222
+ input_ids=txt_tokens.input_ids,
223
+ attention_mask=txt_tokens.attention_mask,
224
+ output_hidden_states=True,
225
+ )
226
+ hidden_states = encoder_hidden_states.hidden_states[-1]
227
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
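+ # Drop the fixed chat-template prefix (the first drop_idx tokens) so that only the user prompt
+ # portion of the encoding conditions generation.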
228
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
229
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
230
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
231
+ prompt_embeds = torch.stack(
232
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
233
+ )
234
+ encoder_attention_mask = torch.stack(
235
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
236
+ )
237
+
238
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
239
+
240
+ return prompt_embeds, encoder_attention_mask
241
+
242
+ def encode_prompt(
243
+ self,
244
+ prompt: Union[str, List[str]],
245
+ device: Optional[torch.device] = None,
246
+ num_images_per_prompt: int = 1,
247
+ prompt_embeds: Optional[torch.Tensor] = None,
248
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
249
+ max_sequence_length: int = 1024,
250
+ ):
251
+ r"""
252
+
253
+ Args:
254
+ prompt (`str` or `List[str]`, *optional*):
255
+ prompt to be encoded
256
+ device: (`torch.device`):
257
+ torch device
258
+ num_images_per_prompt (`int`):
259
+ number of images that should be generated per prompt
260
+ prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
262
+ provided, text embeddings will be generated from `prompt` input argument.
263
+ """
264
+ device = device or self._execution_device
265
+
266
+ prompt = [prompt] if isinstance(prompt, str) else prompt
267
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
268
+
269
+ if prompt_embeds is None:
270
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
271
+
272
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
273
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
274
+
275
+ _, seq_len, _ = prompt_embeds.shape
276
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
277
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
278
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
279
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
280
+
281
+ return prompt_embeds, prompt_embeds_mask
282
+
283
+ def check_inputs(
284
+ self,
285
+ prompt,
286
+ height,
287
+ width,
288
+ negative_prompt=None,
289
+ prompt_embeds=None,
290
+ negative_prompt_embeds=None,
291
+ prompt_embeds_mask=None,
292
+ negative_prompt_embeds_mask=None,
293
+ callback_on_step_end_tensor_inputs=None,
294
+ max_sequence_length=None,
295
+ ):
296
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
297
+ logger.warning(
298
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
299
+ )
300
+
301
+ if callback_on_step_end_tensor_inputs is not None and not all(
302
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
303
+ ):
304
+ raise ValueError(
305
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
306
+ )
307
+
308
+ if prompt is not None and prompt_embeds is not None:
309
+ raise ValueError(
310
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
311
+ " only forward one of the two."
312
+ )
313
+ elif prompt is None and prompt_embeds is None:
314
+ raise ValueError(
315
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
316
+ )
317
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
318
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
319
+
320
+ if negative_prompt is not None and negative_prompt_embeds is not None:
321
+ raise ValueError(
322
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
323
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
324
+ )
325
+
326
+ if prompt_embeds is not None and prompt_embeds_mask is None:
327
+ raise ValueError(
328
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
329
+ )
330
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
331
+ raise ValueError(
332
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
333
+ )
334
+
335
+ if max_sequence_length is not None and max_sequence_length > 1024:
336
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
337
+
338
+ @staticmethod
339
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
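+ # Fold each 2x2 spatial patch of the latents into the channel dimension:
+ # (B, C, H, W) -> (B, (H/2) * (W/2), C * 4), i.e. one token per packed patch.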
340
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
341
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
342
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
343
+
344
+ return latents
345
+
346
+ @staticmethod
347
+ def _unpack_latents(latents, height, width, vae_scale_factor):
348
+ batch_size, num_patches, channels = latents.shape
349
+
350
+ # VAE applies 8x compression on images but we must also account for packing which requires
351
+ # latent height and width to be divisible by 2.
352
+ height = 2 * (int(height) // (vae_scale_factor * 2))
353
+ width = 2 * (int(width) // (vae_scale_factor * 2))
354
+
355
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
356
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
357
+
358
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
359
+
360
+ return latents
361
+
362
+ def enable_vae_slicing(self):
363
+ r"""
364
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
365
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
366
+ """
367
+ self.vae.enable_slicing()
368
+
369
+ def disable_vae_slicing(self):
370
+ r"""
371
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
372
+ computing decoding in one step.
373
+ """
374
+ self.vae.disable_slicing()
375
+
376
+ def enable_vae_tiling(self):
377
+ r"""
378
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
379
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
380
+ processing larger images.
381
+ """
382
+ self.vae.enable_tiling()
383
+
384
+ def disable_vae_tiling(self):
385
+ r"""
386
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
387
+ computing decoding in one step.
388
+ """
389
+ self.vae.disable_tiling()
390
+
391
+ def prepare_latents(
392
+ self,
393
+ batch_size,
394
+ num_channels_latents,
395
+ height,
396
+ width,
397
+ dtype,
398
+ device,
399
+ generator,
400
+ latents=None,
401
+ ):
402
+ # VAE applies 8x compression on images but we must also account for packing which requires
403
+ # latent height and width to be divisible by 2.
404
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
405
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
406
+
407
+ shape = (batch_size, 1, num_channels_latents, height, width)
408
+
409
+ if latents is not None:
410
+ return latents.to(device=device, dtype=dtype)
411
+
412
+ if isinstance(generator, list) and len(generator) != batch_size:
413
+ raise ValueError(
414
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
415
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
416
+ )
417
+
418
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
419
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
420
+
421
+ return latents
422
+
423
+ @property
424
+ def guidance_scale(self):
425
+ return self._guidance_scale
426
+
427
+ @property
428
+ def attention_kwargs(self):
429
+ return self._attention_kwargs
430
+
431
+ @property
432
+ def num_timesteps(self):
433
+ return self._num_timesteps
434
+
435
+ @property
436
+ def current_timestep(self):
437
+ return self._current_timestep
438
+
439
+ @property
440
+ def interrupt(self):
441
+ return self._interrupt
442
+
443
+ @torch.no_grad()
444
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
445
+ def __call__(
446
+ self,
447
+ prompt: Union[str, List[str]] = None,
448
+ negative_prompt: Union[str, List[str]] = None,
449
+ true_cfg_scale: float = 4.0,
450
+ height: Optional[int] = None,
451
+ width: Optional[int] = None,
452
+ num_inference_steps: int = 50,
453
+ sigmas: Optional[List[float]] = None,
454
+ guidance_scale: float = 1.0,
455
+ num_images_per_prompt: int = 1,
456
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
457
+ latents: Optional[torch.Tensor] = None,
458
+ prompt_embeds: Optional[torch.Tensor] = None,
459
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
460
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
461
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
462
+ output_type: Optional[str] = "pil",
463
+ return_dict: bool = True,
464
+ attention_kwargs: Optional[Dict[str, Any]] = None,
465
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
466
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
467
+ max_sequence_length: int = 512,
468
+ comfyui_progressbar: bool = False,
469
+ ):
470
+ r"""
471
+ Function invoked when calling the pipeline for generation.
472
+
473
+ Args:
474
+ prompt (`str` or `List[str]`, *optional*):
475
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
476
+ instead.
477
+ negative_prompt (`str` or `List[str]`, *optional*):
478
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
479
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
480
+ not greater than `1`).
481
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
482
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
483
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
484
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
485
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
486
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
487
+ num_inference_steps (`int`, *optional*, defaults to 50):
488
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
489
+ expense of slower inference.
490
+ sigmas (`List[float]`, *optional*):
491
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
492
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
493
+ will be used.
494
+ guidance_scale (`float`, *optional*, defaults to 1.0):
495
+ Guidance scale as defined in [Classifier-Free Diffusion
496
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
497
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
498
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
499
+ the text `prompt`, usually at the expense of lower image quality.
500
+
501
+ This parameter in the pipeline is there to support future guidance-distilled models when they come up.
502
+ Note that passing `guidance_scale` alone is ineffective. To enable classifier-free guidance,
503
+ pass `true_cfg_scale` greater than 1 together with a `negative_prompt` (even an empty negative
504
+ prompt such as " " enables the classifier-free guidance computations).
505
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
506
+ The number of images to generate per prompt.
507
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
508
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
509
+ to make generation deterministic.
510
+ latents (`torch.Tensor`, *optional*):
511
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
512
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
513
+ tensor will be generated by sampling using the supplied random `generator`.
514
+ prompt_embeds (`torch.Tensor`, *optional*):
515
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
516
+ provided, text embeddings will be generated from `prompt` input argument.
517
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
518
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
519
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
520
+ argument.
521
+ output_type (`str`, *optional*, defaults to `"pil"`):
522
+ The output format of the generated image. Choose between
523
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
524
+ return_dict (`bool`, *optional*, defaults to `True`):
525
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
526
+ attention_kwargs (`dict`, *optional*):
527
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
528
+ `self.processor` in
529
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
530
+ callback_on_step_end (`Callable`, *optional*):
531
+ A function that calls at the end of each denoising steps during the inference. The function is called
532
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
533
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
534
+ `callback_on_step_end_tensor_inputs`.
535
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
536
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
537
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
538
+ `._callback_tensor_inputs` attribute of your pipeline class.
539
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
540
+
541
+ Examples:
542
+
543
+ Returns:
544
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
545
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
546
+ returning a tuple, the first element is a list with the generated images.
547
+ """
548
+
549
+ height = height or self.default_sample_size * self.vae_scale_factor
550
+ width = width or self.default_sample_size * self.vae_scale_factor
551
+
552
+ # 1. Check inputs. Raise error if not correct
553
+ self.check_inputs(
554
+ prompt,
555
+ height,
556
+ width,
557
+ negative_prompt=negative_prompt,
558
+ prompt_embeds=prompt_embeds,
559
+ negative_prompt_embeds=negative_prompt_embeds,
560
+ prompt_embeds_mask=prompt_embeds_mask,
561
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
562
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
563
+ max_sequence_length=max_sequence_length,
564
+ )
565
+
566
+ self._guidance_scale = guidance_scale
567
+ self._attention_kwargs = attention_kwargs
568
+ self._current_timestep = None
569
+ self._interrupt = False
570
+
571
+ # 2. Define call parameters
572
+ if prompt is not None and isinstance(prompt, str):
573
+ batch_size = 1
574
+ elif prompt is not None and isinstance(prompt, list):
575
+ batch_size = len(prompt)
576
+ else:
577
+ batch_size = prompt_embeds.shape[0]
578
+
579
+ device = self._execution_device
580
+ if comfyui_progressbar:
581
+ from comfy.utils import ProgressBar
582
+ pbar = ProgressBar(num_inference_steps + 2)
583
+
584
+ has_neg_prompt = negative_prompt is not None or (
585
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
586
+ )
587
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
588
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
589
+ prompt=prompt,
590
+ prompt_embeds=prompt_embeds,
591
+ prompt_embeds_mask=prompt_embeds_mask,
592
+ device=device,
593
+ num_images_per_prompt=num_images_per_prompt,
594
+ max_sequence_length=max_sequence_length,
595
+ )
596
+ if do_true_cfg:
597
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
598
+ prompt=negative_prompt,
599
+ prompt_embeds=negative_prompt_embeds,
600
+ prompt_embeds_mask=negative_prompt_embeds_mask,
601
+ device=device,
602
+ num_images_per_prompt=num_images_per_prompt,
603
+ max_sequence_length=max_sequence_length,
604
+ )
605
+
606
+ # 4. Prepare latent variables
607
+ num_channels_latents = self.transformer.config.in_channels // 4
608
+ latents = self.prepare_latents(
609
+ batch_size * num_images_per_prompt,
610
+ num_channels_latents,
611
+ height,
612
+ width,
613
+ prompt_embeds.dtype,
614
+ device,
615
+ generator,
616
+ latents,
617
+ )
618
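+ # Per-sample packed latent grid: (frames, latent_height / 2, latent_width / 2); the transformer
+ # uses these shapes to lay out positional embeddings for the packed image tokens.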
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
619
+ if comfyui_progressbar:
620
+ pbar.update(1)
621
+
622
+ # 5. Prepare timesteps
623
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
624
+ image_seq_len = latents.shape[1]
625
+ mu = calculate_shift(
626
+ image_seq_len,
627
+ self.scheduler.config.get("base_image_seq_len", 256),
628
+ self.scheduler.config.get("max_image_seq_len", 4096),
629
+ self.scheduler.config.get("base_shift", 0.5),
630
+ self.scheduler.config.get("max_shift", 1.15),
631
+ )
632
+ timesteps, num_inference_steps = retrieve_timesteps(
633
+ self.scheduler,
634
+ num_inference_steps,
635
+ device,
636
+ sigmas=sigmas,
637
+ mu=mu,
638
+ )
639
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
640
+ self._num_timesteps = len(timesteps)
641
+
642
+ # handle guidance
643
+ if self.transformer.config.guidance_embeds:
644
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
645
+ guidance = guidance.expand(latents.shape[0])
646
+ else:
647
+ guidance = None
648
+
649
+ if self.attention_kwargs is None:
650
+ self._attention_kwargs = {}
651
+
652
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
653
+ negative_txt_seq_lens = (
654
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
655
+ )
656
+ if comfyui_progressbar:
657
+ pbar.update(1)
658
+
659
+ # 6. Denoising loop
660
+ self.scheduler.set_begin_index(0)
661
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
662
+ for i, t in enumerate(timesteps):
663
+ self.transformer.current_steps = i
664
+ if self.interrupt:
665
+ continue
666
+
667
+ if do_true_cfg:
668
+ latent_model_input = torch.cat([latents] * 2)
669
+ prompt_embeds_mask_input = list(negative_prompt_embeds_mask) + list(prompt_embeds_mask)
670
+ prompt_embeds_input = list(negative_prompt_embeds) + list(prompt_embeds)
671
+ img_shapes_input = img_shapes * 2
672
+ txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
673
+ else:
674
+ latent_model_input = latents
675
+ prompt_embeds_mask_input = prompt_embeds_mask
676
+ prompt_embeds_input = prompt_embeds
677
+ img_shapes_input = img_shapes
678
+ txt_seq_lens_input = txt_seq_lens
679
+
680
+ if hasattr(self.scheduler, "scale_model_input"):
681
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
682
+
683
+ # handle guidance
684
+ if self.transformer.config.guidance_embeds:
685
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
686
+ guidance = guidance.expand(latent_model_input.shape[0])
687
+ else:
688
+ guidance = None
689
+
690
+ self._current_timestep = t
691
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
692
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
693
+
694
+ with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
695
+ noise_pred = self.transformer.forward_bs(
696
+ x=latent_model_input,
697
+ timestep=timestep / 1000,
698
+ guidance=guidance,
699
+ encoder_hidden_states_mask=prompt_embeds_mask_input,
700
+ encoder_hidden_states=prompt_embeds_input,
701
+ img_shapes=img_shapes_input,
702
+ txt_seq_lens=txt_seq_lens_input,
703
+ attention_kwargs=self.attention_kwargs,
704
+ return_dict=False,
705
+ )
706
+
707
+ if do_true_cfg:
708
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
709
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
710
+
711
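+ # Rescale the guided prediction so its per-token norm matches the conditional prediction,
+ # which limits over-saturation at high `true_cfg_scale` values.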
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
712
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
713
+ noise_pred = comb_pred * (cond_norm / noise_norm)
714
+
715
+ # compute the previous noisy sample x_t -> x_t-1
716
+ latents_dtype = latents.dtype
717
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
718
+
719
+ if latents.dtype != latents_dtype:
720
+ if torch.backends.mps.is_available():
721
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
722
+ latents = latents.to(latents_dtype)
723
+
724
+ if callback_on_step_end is not None:
725
+ callback_kwargs = {}
726
+ for k in callback_on_step_end_tensor_inputs:
727
+ callback_kwargs[k] = locals()[k]
728
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
729
+
730
+ latents = callback_outputs.pop("latents", latents)
731
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
732
+
733
+ # call the callback, if provided
734
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
735
+ progress_bar.update()
736
+
737
+ if XLA_AVAILABLE:
738
+ xm.mark_step()
739
+
740
+ if comfyui_progressbar:
741
+ pbar.update(1)
742
+
743
+ self._current_timestep = None
744
+ if output_type == "latent":
745
+ image = latents
746
+ else:
747
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
748
+ latents = latents.to(self.vae.dtype)
749
+ latents_mean = (
750
+ torch.tensor(self.vae.config.latents_mean)
751
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
752
+ .to(latents.device, latents.dtype)
753
+ )
754
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
755
+ latents.device, latents.dtype
756
+ )
757
+ latents = latents / latents_std + latents_mean
758
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
759
+ image = self.image_processor.postprocess(image, output_type=output_type)
760
+
761
+ # Offload all models
762
+ self.maybe_free_model_hooks()
763
+
764
+ if not return_dict:
765
+ return (image,)
766
+
767
+ return QwenImagePipelineOutput(images=image)
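For a concrete call pattern, the example docstring near the top of this file can be extended with the classifier-free-guidance arguments exposed by `__call__`. A short sketch against the upstream `diffusers` pipeline that this file is modified from (the checkpoint id is taken from that docstring; the argument values are only illustrative):

```py
import torch
from diffusers import QwenImagePipeline  # upstream pipeline this file is modified from

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt=" ",   # any negative prompt, even a blank one, enables true CFG
    true_cfg_scale=4.0,    # matches the default in the signature above
    height=1024,
    width=1024,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("qwenimage.png")
```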
videox_fun/pipeline/pipeline_qwenimage_edit.py ADDED
@@ -0,0 +1,952 @@
1
+ # Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
2
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import math
22
+ import PIL.Image
23
+ import torch
24
+ from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available,
28
+ logging, replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ from ..models import (AutoencoderKLQwenImage,
32
+ Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor,
33
+ Qwen2Tokenizer, QwenImageTransformer2DModel)
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from PIL import Image
49
+ >>> from diffusers import QwenImageEditPipeline
50
+
51
+ >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
52
+ >>> pipe.to("cuda")
53
+ >>> prompt = "Change the cat to a dog"
54
+ >>> image = Image.open("cat.png")
55
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
56
+ >>> # Refer to the pipeline documentation for more details.
57
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
58
+ >>> image.save("qwenimageedit.png")
59
+ ```
60
+ """
61
+ PREFERRED_QWENIMAGE_RESOLUTIONS = [
62
+ (672, 1568),
63
+ (688, 1504),
64
+ (720, 1456),
65
+ (752, 1392),
66
+ (800, 1328),
67
+ (832, 1248),
68
+ (880, 1184),
69
+ (944, 1104),
70
+ (1024, 1024),
71
+ (1104, 944),
72
+ (1184, 880),
73
+ (1248, 832),
74
+ (1328, 800),
75
+ (1392, 752),
76
+ (1456, 720),
77
+ (1504, 688),
78
+ (1568, 672),
79
+ ]
80
+
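`PREFERRED_QWENIMAGE_RESOLUTIONS` is not referenced in the code shown here; if one wanted to snap an arbitrary input to the closest entry by aspect ratio, a simple argmin would do. A hypothetical helper, not part of the pipeline, where the orientation of the tuple entries is an assumption (the comparison only requires that input and table use the same order):

```py
PREFERRED_QWENIMAGE_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

def closest_preferred_resolution(side_a, side_b):
    # Pick the table entry whose aspect ratio is closest to side_a / side_b.
    target_ratio = side_a / side_b
    return min(PREFERRED_QWENIMAGE_RESOLUTIONS,
               key=lambda ab: abs(ab[0] / ab[1] - target_ratio))

print(closest_preferred_resolution(720, 1280))  # (752, 1392) for a 16:9-ish input
```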
81
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
82
+ def calculate_shift(
83
+ image_seq_len,
84
+ base_seq_len: int = 256,
85
+ max_seq_len: int = 4096,
86
+ base_shift: float = 0.5,
87
+ max_shift: float = 1.15,
88
+ ):
89
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
90
+ b = base_shift - m * base_seq_len
91
+ mu = image_seq_len * m + b
92
+ return mu
93
+
94
+
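`calculate_shift` linearly interpolates the flow-matching shift `mu` between `base_shift` and `max_shift` as the packed image sequence length grows from 256 to 4096 tokens. A quick check with the default arguments (the printed values follow directly from the linear formula):

```py
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Straight line through (base_seq_len, base_shift) and (max_seq_len, max_shift).
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

print(calculate_shift(256))   # 0.5  (at the base sequence length)
print(calculate_shift(4096))  # 1.15 (at the max sequence length)
print(calculate_shift(2048))  # ~0.803, partway along the line
```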
95
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
96
+ def retrieve_timesteps(
97
+ scheduler,
98
+ num_inference_steps: Optional[int] = None,
99
+ device: Optional[Union[str, torch.device]] = None,
100
+ timesteps: Optional[List[int]] = None,
101
+ sigmas: Optional[List[float]] = None,
102
+ **kwargs,
103
+ ):
104
+ r"""
105
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
+
108
+ Args:
109
+ scheduler (`SchedulerMixin`):
110
+ The scheduler to get timesteps from.
111
+ num_inference_steps (`int`):
112
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
113
+ must be `None`.
114
+ device (`str` or `torch.device`, *optional*):
115
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
116
+ timesteps (`List[int]`, *optional*):
117
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
118
+ `num_inference_steps` and `sigmas` must be `None`.
119
+ sigmas (`List[float]`, *optional*):
120
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
121
+ `num_inference_steps` and `timesteps` must be `None`.
122
+
123
+ Returns:
124
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
125
+ second element is the number of inference steps.
126
+ """
127
+ if timesteps is not None and sigmas is not None:
128
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
129
+ if timesteps is not None:
130
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
131
+ if not accepts_timesteps:
132
+ raise ValueError(
133
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
134
+ f" timestep schedules. Please check whether you are using the correct scheduler."
135
+ )
136
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
137
+ timesteps = scheduler.timesteps
138
+ num_inference_steps = len(timesteps)
139
+ elif sigmas is not None:
140
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
+ if not accept_sigmas:
142
+ raise ValueError(
143
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
145
+ )
146
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ num_inference_steps = len(timesteps)
149
+ else:
150
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ return timesteps, num_inference_steps
153
+
154
+
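In this pipeline, `retrieve_timesteps` is always called on the `sigmas` branch and forwards the extra `mu` keyword to `scheduler.set_timesteps`. A minimal sketch of what that amounts to with diffusers' `FlowMatchEulerDiscreteScheduler`, mirroring the call made later in `__call__` (the `mu=0.8` value is just an illustrative shift):

```py
import numpy as np
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

# use_dynamic_shifting lets set_timesteps accept the resolution-dependent `mu`.
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)

num_inference_steps = 8
# Same linearly spaced sigmas the pipeline builds when none are supplied.
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

# This is what retrieve_timesteps does under the hood on the sigmas branch.
scheduler.set_timesteps(sigmas=sigmas, device="cpu", mu=0.8)
timesteps = scheduler.timesteps

print(len(timesteps), float(timesteps[0]))  # 8 steps; the first timestep is 1000.0 here
```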
155
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
156
+ def retrieve_latents(
157
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
158
+ ):
159
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
160
+ return encoder_output.latent_dist.sample(generator)
161
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
162
+ return encoder_output.latent_dist.mode()
163
+ elif hasattr(encoder_output, "latents"):
164
+ return encoder_output.latents
165
+ else:
166
+ raise AttributeError("Could not access latents of provided encoder_output")
167
+
168
+
169
+ def calculate_dimensions(target_area, ratio):
170
+ width = math.sqrt(target_area * ratio)
171
+ height = width / ratio
172
+
173
+ width = round(width / 32) * 32
174
+ height = round(height / 32) * 32
175
+
176
+ return width, height
177
+
178
+
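`calculate_dimensions` solves `width * height = target_area` for a fixed aspect ratio and then snaps both sides to multiples of 32, the packing granularity of the latent grid. A worked example: a 1536x1024 reference image (ratio 1.5) mapped onto a 1024*1024 target area lands on 1248x832, which also happens to be one of the preferred resolutions above.

```py
import math

def calculate_dimensions(target_area, ratio):
    # Ideal (possibly fractional) dimensions for the requested area and aspect ratio.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    # Snap to multiples of 32 so the packed latent grid stays valid.
    width = round(width / 32) * 32
    height = round(height / 32) * 32
    return width, height

print(calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832)
```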
179
+ @dataclass
180
+ class QwenImagePipelineOutput(BaseOutput):
181
+ """
182
+ Output class for Stable Diffusion pipelines.
183
+
184
+ Args:
185
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
186
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
187
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
188
+ """
189
+
190
+ images: Union[List[PIL.Image.Image], np.ndarray]
191
+
192
+
193
+ class QwenImageEditPipeline(DiffusionPipeline):
194
+ r"""
195
+ The QwenImage Edit pipeline for image editing.
196
+
197
+ Args:
198
+ transformer ([`QwenImageTransformer2DModel`]):
199
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
200
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
201
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
202
+ vae ([`AutoencoderKL`]):
203
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
204
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
205
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
206
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
207
+ tokenizer (`Qwen2Tokenizer`):
208
+ Tokenizer of class
209
+ [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
210
+ """
211
+
212
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
213
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
214
+
215
+ def __init__(
216
+ self,
217
+ scheduler: FlowMatchEulerDiscreteScheduler,
218
+ vae: AutoencoderKLQwenImage,
219
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
220
+ tokenizer: Qwen2Tokenizer,
221
+ processor: Qwen2VLProcessor,
222
+ transformer: QwenImageTransformer2DModel,
223
+ ):
224
+ super().__init__()
225
+
226
+ self.register_modules(
227
+ vae=vae,
228
+ text_encoder=text_encoder,
229
+ tokenizer=tokenizer,
230
+ processor=processor,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ )
234
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
235
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
236
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
237
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
238
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
239
+ self.tokenizer_max_length = 1024
240
+
241
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
242
+ self.prompt_template_encode_start_idx = 64
243
+ self.default_sample_size = 128
244
+
245
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
246
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
247
+ bool_mask = mask.bool()
248
+ valid_lengths = bool_mask.sum(dim=1)
249
+ selected = hidden_states[bool_mask]
250
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
251
+
252
+ return split_result
253
+
254
+ def _get_qwen_prompt_embeds(
255
+ self,
256
+ prompt: Union[str, List[str]] = None,
257
+ image: Optional[torch.Tensor] = None,
258
+ device: Optional[torch.device] = None,
259
+ dtype: Optional[torch.dtype] = None,
260
+ ):
261
+ device = device or self._execution_device
262
+ dtype = dtype or self.text_encoder.dtype
263
+
264
+ prompt = [prompt] if isinstance(prompt, str) else prompt
265
+
266
+ template = self.prompt_template_encode
267
+ drop_idx = self.prompt_template_encode_start_idx
268
+ txt = [template.format(e) for e in prompt]
269
+
270
+ model_inputs = self.processor(
271
+ text=txt,
272
+ images=image,
273
+ padding=True,
274
+ return_tensors="pt",
275
+ ).to(device)
276
+
277
+ outputs = self.text_encoder(
278
+ input_ids=model_inputs.input_ids,
279
+ attention_mask=model_inputs.attention_mask,
280
+ pixel_values=model_inputs.pixel_values,
281
+ image_grid_thw=model_inputs.image_grid_thw,
282
+ output_hidden_states=True,
283
+ )
284
+
285
+ hidden_states = outputs.hidden_states[-1]
286
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
287
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
288
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
289
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
290
+ prompt_embeds = torch.stack(
291
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
292
+ )
293
+ encoder_attention_mask = torch.stack(
294
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
295
+ )
296
+
297
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
298
+
299
+ return prompt_embeds, encoder_attention_mask
300
+
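`_get_qwen_prompt_embeds` keeps only the non-padded hidden states per sample (via `_extract_masked_hidden`), drops the first `drop_idx` template tokens, and re-pads the variable-length remainders into a dense batch with a fresh attention mask. The same mechanics on toy tensors (shapes and values are made up):

```py
import torch

hidden = torch.arange(2 * 6 * 4, dtype=torch.float32).view(2, 6, 4)  # (batch, seq, dim)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])                             # padding mask
drop_idx = 2                                                          # template prefix to discard

# Keep only valid tokens, then split back into per-sample sequences.
valid = hidden[mask.bool()]
pieces = torch.split(valid, mask.sum(dim=1).tolist(), dim=0)
pieces = [p[drop_idx:] for p in pieces]                               # drop the prompt template

# Re-pad to the longest remaining sequence and build the new attention mask.
max_len = max(p.size(0) for p in pieces)
embeds = torch.stack([torch.cat([p, p.new_zeros(max_len - p.size(0), p.size(1))]) for p in pieces])
attn = torch.stack([torch.cat([torch.ones(p.size(0), dtype=torch.long),
                               torch.zeros(max_len - p.size(0), dtype=torch.long)]) for p in pieces])

print(embeds.shape, attn.tolist())  # torch.Size([2, 4, 4]) and [[1, 1, 0, 0], [1, 1, 1, 1]]
```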
301
+ def encode_prompt(
302
+ self,
303
+ prompt: Union[str, List[str]],
304
+ image: Optional[torch.Tensor] = None,
305
+ device: Optional[torch.device] = None,
306
+ num_images_per_prompt: int = 1,
307
+ prompt_embeds: Optional[torch.Tensor] = None,
308
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
309
+ max_sequence_length: int = 1024,
310
+ ):
311
+ r"""
312
+
313
+ Args:
314
+ prompt (`str` or `List[str]`, *optional*):
315
+ prompt to be encoded
316
+ image (`torch.Tensor`, *optional*):
317
+ image to be encoded
318
+ device: (`torch.device`):
319
+ torch device
320
+ num_images_per_prompt (`int`):
321
+ number of images that should be generated per prompt
322
+ prompt_embeds (`torch.Tensor`, *optional*):
323
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
324
+ provided, text embeddings will be generated from `prompt` input argument.
325
+ """
326
+ device = device or self._execution_device
327
+
328
+ prompt = [prompt] if isinstance(prompt, str) else prompt
329
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
330
+
331
+ if prompt_embeds is None:
332
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
333
+
334
+ _, seq_len, _ = prompt_embeds.shape
335
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
336
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
337
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
338
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
339
+
340
+ return prompt_embeds, prompt_embeds_mask
341
+
342
+ def check_inputs(
343
+ self,
344
+ prompt,
345
+ height,
346
+ width,
347
+ negative_prompt=None,
348
+ prompt_embeds=None,
349
+ negative_prompt_embeds=None,
350
+ prompt_embeds_mask=None,
351
+ negative_prompt_embeds_mask=None,
352
+ callback_on_step_end_tensor_inputs=None,
353
+ max_sequence_length=None,
354
+ ):
355
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
356
+ logger.warning(
357
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
358
+ )
359
+
360
+ if callback_on_step_end_tensor_inputs is not None and not all(
361
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
362
+ ):
363
+ raise ValueError(
364
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
365
+ )
366
+
367
+ if prompt is not None and prompt_embeds is not None:
368
+ raise ValueError(
369
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
370
+ " only forward one of the two."
371
+ )
372
+ elif prompt is None and prompt_embeds is None:
373
+ raise ValueError(
374
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
375
+ )
376
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
377
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
378
+
379
+ if negative_prompt is not None and negative_prompt_embeds is not None:
380
+ raise ValueError(
381
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
382
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
383
+ )
384
+
385
+ if prompt_embeds is not None and prompt_embeds_mask is None:
386
+ raise ValueError(
387
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
388
+ )
389
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
390
+ raise ValueError(
391
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
392
+ )
393
+
394
+ if max_sequence_length is not None and max_sequence_length > 1024:
395
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
396
+
397
+ @staticmethod
398
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
399
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
400
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
401
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
402
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
403
+
404
+ return latents
405
+
406
+ @staticmethod
407
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
408
+ def _unpack_latents(latents, height, width, vae_scale_factor):
409
+ batch_size, num_patches, channels = latents.shape
410
+
411
+ # VAE applies 8x compression on images but we must also account for packing which requires
412
+ # latent height and width to be divisible by 2.
413
+ height = 2 * (int(height) // (vae_scale_factor * 2))
414
+ width = 2 * (int(width) // (vae_scale_factor * 2))
415
+
416
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
417
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
418
+
419
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
420
+
421
+ return latents
422
+
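`_pack_latents` folds each 2x2 spatial patch of the latent grid into the channel dimension, turning a `(B, C, H, W)` latent into a `(B, (H/2)*(W/2), 4C)` token sequence, and `_unpack_latents` reverses it while re-inserting the singleton frame axis. A round-trip check on a toy tensor:

```py
import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 spatial patch becomes one token.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

def unpack_latents(latents, height, width, vae_scale_factor):
    # Inverse of pack_latents; `height`/`width` here are pixel sizes, hence the extra scaling.
    batch_size, _, channels = latents.shape
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // 4, 1, height, width)

vae_scale_factor = 8
x = torch.randn(1, 16, 8, 8)                      # latent for a 64x64 image
tokens = pack_latents(x, 1, 16, 8, 8)
print(tokens.shape)                               # torch.Size([1, 16, 64])
x_back = unpack_latents(tokens, 64, 64, vae_scale_factor)
assert torch.equal(x_back.squeeze(2), x)          # exact round trip
```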
423
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
424
+ if isinstance(generator, list):
425
+ image_latents = [
426
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
427
+ for i in range(image.shape[0])
428
+ ]
429
+ image_latents = torch.cat(image_latents, dim=0)
430
+ else:
431
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
432
+ latents_mean = (
433
+ torch.tensor(self.vae.config.latents_mean)
434
+ .view(1, self.latent_channels, 1, 1, 1)
435
+ .to(image_latents.device, image_latents.dtype)
436
+ )
437
+ latents_std = (
438
+ torch.tensor(self.vae.config.latents_std)
439
+ .view(1, self.latent_channels, 1, 1, 1)
440
+ .to(image_latents.device, image_latents.dtype)
441
+ )
442
+ image_latents = (image_latents - latents_mean) / latents_std
443
+
444
+ return image_latents
445
+
446
+ def enable_vae_slicing(self):
447
+ r"""
448
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
449
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
450
+ """
451
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
452
+ deprecate(
453
+ "enable_vae_slicing",
454
+ "0.40.0",
455
+ depr_message,
456
+ )
457
+ self.vae.enable_slicing()
458
+
459
+ def disable_vae_slicing(self):
460
+ r"""
461
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
462
+ computing decoding in one step.
463
+ """
464
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
465
+ deprecate(
466
+ "disable_vae_slicing",
467
+ "0.40.0",
468
+ depr_message,
469
+ )
470
+ self.vae.disable_slicing()
471
+
472
+ def enable_vae_tiling(self):
473
+ r"""
474
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
475
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
476
+ processing larger images.
477
+ """
478
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
479
+ deprecate(
480
+ "enable_vae_tiling",
481
+ "0.40.0",
482
+ depr_message,
483
+ )
484
+ self.vae.enable_tiling()
485
+
486
+ def disable_vae_tiling(self):
487
+ r"""
488
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
489
+ computing decoding in one step.
490
+ """
491
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
492
+ deprecate(
493
+ "disable_vae_tiling",
494
+ "0.40.0",
495
+ depr_message,
496
+ )
497
+ self.vae.disable_tiling()
498
+
499
+ def prepare_latents(
500
+ self,
501
+ image,
502
+ batch_size,
503
+ num_channels_latents,
504
+ height,
505
+ width,
506
+ dtype,
507
+ device,
508
+ generator,
509
+ latents=None,
510
+ ):
511
+ # VAE applies 8x compression on images but we must also account for packing which requires
512
+ # latent height and width to be divisible by 2.
513
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
514
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
515
+
516
+ shape = (batch_size, 1, num_channels_latents, height, width)
517
+
518
+ image_latents = None
519
+ if image is not None:
520
+ image = image.to(device=device, dtype=dtype)
521
+ if image.shape[1] != self.latent_channels:
522
+ image_latents = self._encode_vae_image(image=image, generator=generator)
523
+ else:
524
+ image_latents = image
525
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
526
+ # expand init_latents for batch_size
527
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
528
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
529
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
530
+ raise ValueError(
531
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
532
+ )
533
+ else:
534
+ image_latents = torch.cat([image_latents], dim=0)
535
+
536
+ image_latent_height, image_latent_width = image_latents.shape[3:]
537
+ image_latents = self._pack_latents(
538
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
539
+ )
540
+
541
+ if isinstance(generator, list) and len(generator) != batch_size:
542
+ raise ValueError(
543
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
544
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
545
+ )
546
+ if latents is None:
547
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
548
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
549
+ else:
550
+ latents = latents.to(device=device, dtype=dtype)
551
+
552
+ return latents, image_latents
553
+
554
+ @property
555
+ def guidance_scale(self):
556
+ return self._guidance_scale
557
+
558
+ @property
559
+ def attention_kwargs(self):
560
+ return self._attention_kwargs
561
+
562
+ @property
563
+ def num_timesteps(self):
564
+ return self._num_timesteps
565
+
566
+ @property
567
+ def current_timestep(self):
568
+ return self._current_timestep
569
+
570
+ @property
571
+ def interrupt(self):
572
+ return self._interrupt
573
+
574
+ @torch.no_grad()
575
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
576
+ def __call__(
577
+ self,
578
+ image = None,
579
+ prompt: Union[str, List[str]] = None,
580
+ negative_prompt: Union[str, List[str]] = None,
581
+ true_cfg_scale: float = 4.0,
582
+ height: Optional[int] = None,
583
+ width: Optional[int] = None,
584
+ num_inference_steps: int = 50,
585
+ sigmas: Optional[List[float]] = None,
586
+ guidance_scale: Optional[float] = None,
587
+ num_images_per_prompt: int = 1,
588
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
589
+ latents: Optional[torch.Tensor] = None,
590
+ prompt_embeds: Optional[torch.Tensor] = None,
591
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
592
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
593
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
594
+ output_type: Optional[str] = "pil",
595
+ return_dict: bool = True,
596
+ attention_kwargs: Optional[Dict[str, Any]] = None,
597
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
598
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
599
+ max_sequence_length: int = 512,
600
+ comfyui_progressbar: bool = False,
601
+ ):
602
+ r"""
603
+ Function invoked when calling the pipeline for generation.
604
+
605
+ Args:
606
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
607
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
608
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
609
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
610
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
611
+ latents as `image`, but if passing latents directly it is not encoded again.
612
+ prompt (`str` or `List[str]`, *optional*):
613
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
614
+ instead.
615
+ negative_prompt (`str` or `List[str]`, *optional*):
616
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
617
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
618
+ not greater than `1`).
619
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
620
+ Guidance scale as defined in [Classifier-Free
621
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
622
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
623
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
624
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
625
+ lower image quality.
626
+ height (`int`, *optional*):
627
+ The height in pixels of the generated image. If not provided, it is derived from the input image's aspect ratio at a target area of 1024*1024.
628
+ width (`int`, *optional*):
629
+ The width in pixels of the generated image. If not provided, it is derived from the input image's aspect ratio at a target area of 1024*1024.
630
+ num_inference_steps (`int`, *optional*, defaults to 50):
631
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
632
+ expense of slower inference.
633
+ sigmas (`List[float]`, *optional*):
634
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
635
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
636
+ will be used.
637
+ guidance_scale (`float`, *optional*, defaults to None):
638
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
639
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
640
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
641
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
642
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
643
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
644
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
645
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
646
+ enable classifier-free guidance computations).
647
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
648
+ The number of images to generate per prompt.
649
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
650
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
651
+ to make generation deterministic.
652
+ latents (`torch.Tensor`, *optional*):
653
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
654
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
655
+ tensor will be generated by sampling using the supplied random `generator`.
656
+ prompt_embeds (`torch.Tensor`, *optional*):
657
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
658
+ provided, text embeddings will be generated from `prompt` input argument.
659
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
660
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
661
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
662
+ argument.
663
+ output_type (`str`, *optional*, defaults to `"pil"`):
664
+ The output format of the generate image. Choose between
665
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
666
+ return_dict (`bool`, *optional*, defaults to `True`):
667
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
668
+ attention_kwargs (`dict`, *optional*):
669
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
670
+ `self.processor` in
671
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
672
+ callback_on_step_end (`Callable`, *optional*):
673
+ A function that calls at the end of each denoising steps during the inference. The function is called
674
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
675
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
676
+ `callback_on_step_end_tensor_inputs`.
677
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
678
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
679
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
680
+ `._callback_tensor_inputs` attribute of your pipeline class.
681
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
682
+
683
+ Examples:
684
+
685
+ Returns:
686
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
687
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
688
+ returning a tuple, the first element is a list with the generated images.
689
+ """
690
+ image_size = image[0].size if isinstance(image, list) else image.size
691
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
692
+ height = height or calculated_height
693
+ width = width or calculated_width
694
+
695
+ multiple_of = self.vae_scale_factor * 2
696
+ width = width // multiple_of * multiple_of
697
+ height = height // multiple_of * multiple_of
698
+
699
+ # 1. Check inputs. Raise error if not correct
700
+ self.check_inputs(
701
+ prompt,
702
+ height,
703
+ width,
704
+ negative_prompt=negative_prompt,
705
+ prompt_embeds=prompt_embeds,
706
+ negative_prompt_embeds=negative_prompt_embeds,
707
+ prompt_embeds_mask=prompt_embeds_mask,
708
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
709
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
710
+ max_sequence_length=max_sequence_length,
711
+ )
712
+
713
+ self._guidance_scale = guidance_scale
714
+ self._attention_kwargs = attention_kwargs
715
+ self._current_timestep = None
716
+ self._interrupt = False
717
+
718
+ # 2. Define call parameters
719
+ if prompt is not None and isinstance(prompt, str):
720
+ batch_size = 1
721
+ elif prompt is not None and isinstance(prompt, list):
722
+ batch_size = len(prompt)
723
+ else:
724
+ batch_size = prompt_embeds.shape[0]
725
+
726
+ device = self._execution_device
727
+ if comfyui_progressbar:
728
+ from comfy.utils import ProgressBar
729
+ pbar = ProgressBar(num_inference_steps + 2)
730
+
731
+ # 3. Preprocess image
732
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
733
+ image = self.image_processor.resize(image, calculated_height, calculated_width)
734
+ prompt_image = image
735
+ image = self.image_processor.preprocess(image, calculated_height, calculated_width)
736
+ image = image.unsqueeze(2)
737
+
738
+ has_neg_prompt = negative_prompt is not None or (
739
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
740
+ )
741
+
742
+ if true_cfg_scale > 1 and not has_neg_prompt:
743
+ logger.warning(
744
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
745
+ )
746
+ elif true_cfg_scale <= 1 and has_neg_prompt:
747
+ logger.warning(
748
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
749
+ )
750
+
751
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
752
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
753
+ image=prompt_image,
754
+ prompt=prompt,
755
+ prompt_embeds=prompt_embeds,
756
+ prompt_embeds_mask=prompt_embeds_mask,
757
+ device=device,
758
+ num_images_per_prompt=num_images_per_prompt,
759
+ max_sequence_length=max_sequence_length,
760
+ )
761
+ if do_true_cfg:
762
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
763
+ image=prompt_image,
764
+ prompt=negative_prompt,
765
+ prompt_embeds=negative_prompt_embeds,
766
+ prompt_embeds_mask=negative_prompt_embeds_mask,
767
+ device=device,
768
+ num_images_per_prompt=num_images_per_prompt,
769
+ max_sequence_length=max_sequence_length,
770
+ )
771
+ if comfyui_progressbar:
772
+ pbar.update(1)
773
+
774
+ # 4. Prepare latent variables
775
+ num_channels_latents = self.transformer.config.in_channels // 4
776
+ latents, image_latents = self.prepare_latents(
777
+ image,
778
+ batch_size * num_images_per_prompt,
779
+ num_channels_latents,
780
+ height,
781
+ width,
782
+ prompt_embeds.dtype,
783
+ device,
784
+ generator,
785
+ latents,
786
+ )
787
+ img_shapes = [
788
+ [
789
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
790
+ (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
791
+ ]
792
+ ] * batch_size
793
+
794
+ # 5. Prepare timesteps
795
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
796
+ image_seq_len = latents.shape[1]
797
+ mu = calculate_shift(
798
+ image_seq_len,
799
+ self.scheduler.config.get("base_image_seq_len", 256),
800
+ self.scheduler.config.get("max_image_seq_len", 4096),
801
+ self.scheduler.config.get("base_shift", 0.5),
802
+ self.scheduler.config.get("max_shift", 1.15),
803
+ )
804
+ timesteps, num_inference_steps = retrieve_timesteps(
805
+ self.scheduler,
806
+ num_inference_steps,
807
+ device,
808
+ sigmas=sigmas,
809
+ mu=mu,
810
+ )
811
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
812
+ self._num_timesteps = len(timesteps)
813
+
814
+ # handle guidance
815
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
816
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
817
+ elif self.transformer.config.guidance_embeds:
818
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
819
+ guidance = guidance.expand(latents.shape[0])
820
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
821
+ logger.warning(
822
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
823
+ )
824
+ guidance = None
825
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
826
+ guidance = None
827
+
828
+ if self.attention_kwargs is None:
829
+ self._attention_kwargs = {}
830
+
831
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
832
+ negative_txt_seq_lens = (
833
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
834
+ )
835
+ if comfyui_progressbar:
836
+ pbar.update(1)
837
+
838
+ # 6. Denoising loop
839
+ self.scheduler.set_begin_index(0)
840
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
841
+ for i, t in enumerate(timesteps):
842
+ self.transformer.current_steps = i
843
+ if self.interrupt:
844
+ continue
845
+
846
+ if image_latents is not None:
847
+ latents_and_image_latents = torch.cat([latents, image_latents], dim=1)
848
+ else:
849
+ latents_and_image_latents = latents
850
+
851
+ if do_true_cfg:
852
+ latent_model_input = torch.cat([latents_and_image_latents] * 2)
853
+ prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask]
854
+ prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds]
855
+ img_shapes_input = img_shapes * 2
856
+ txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
857
+ else:
858
+ latent_model_input = latents_and_image_latents
859
+ prompt_embeds_mask_input = prompt_embeds_mask
860
+ prompt_embeds_input = prompt_embeds
861
+ img_shapes_input = img_shapes
862
+ txt_seq_lens_input = txt_seq_lens
863
+
864
+ if hasattr(self.scheduler, "scale_model_input"):
865
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
866
+
867
+ # handle guidance
868
+ if self.transformer.config.guidance_embeds:
869
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
870
+ guidance = guidance.expand(latent_model_input.shape[0])
871
+ else:
872
+ guidance = None
873
+
874
+ self._current_timestep = t
875
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
876
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
877
+
878
+ with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
879
+ noise_pred = self.transformer.forward_bs(
880
+ x=latent_model_input,
881
+ timestep=timestep / 1000,
882
+ guidance=guidance,
883
+ encoder_hidden_states_mask=prompt_embeds_mask_input,
884
+ encoder_hidden_states=prompt_embeds_input,
885
+ img_shapes=img_shapes_input,
886
+ txt_seq_lens=txt_seq_lens_input,
887
+ attention_kwargs=self.attention_kwargs,
888
+ return_dict=False,
889
+ )
890
+ noise_pred = noise_pred[:, : latents.size(1)]
891
+
892
+ if do_true_cfg:
893
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
894
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
895
+
896
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
897
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
898
+ noise_pred = comb_pred * (cond_norm / noise_norm)
899
+
900
+ # compute the previous noisy sample x_t -> x_t-1
901
+ latents_dtype = latents.dtype
902
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
903
+
904
+ if latents.dtype != latents_dtype:
905
+ if torch.backends.mps.is_available():
906
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
907
+ latents = latents.to(latents_dtype)
908
+
909
+ if callback_on_step_end is not None:
910
+ callback_kwargs = {}
911
+ for k in callback_on_step_end_tensor_inputs:
912
+ callback_kwargs[k] = locals()[k]
913
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
914
+
915
+ latents = callback_outputs.pop("latents", latents)
916
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
917
+
918
+ # call the callback, if provided
919
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
920
+ progress_bar.update()
921
+
922
+ if XLA_AVAILABLE:
923
+ xm.mark_step()
924
+
925
+ if comfyui_progressbar:
926
+ pbar.update(1)
927
+
928
+ self._current_timestep = None
929
+ if output_type == "latent":
930
+ image = latents
931
+ else:
932
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
933
+ latents = latents.to(self.vae.dtype)
934
+ latents_mean = (
935
+ torch.tensor(self.vae.config.latents_mean)
936
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
937
+ .to(latents.device, latents.dtype)
938
+ )
939
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
940
+ latents.device, latents.dtype
941
+ )
942
+ latents = latents / latents_std + latents_mean
943
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
944
+ image = self.image_processor.postprocess(image, output_type=output_type)
945
+
946
+ # Offload all models
947
+ self.maybe_free_model_hooks()
948
+
949
+ if not return_dict:
950
+ return (image,)
951
+
952
+ return QwenImagePipelineOutput(images=image)
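The denoising loop above applies classifier-free guidance explicitly (`true_cfg_scale`) and then rescales the combined prediction so its per-token norm matches the conditional branch, which keeps the guided update from blowing up at high scales. A sketch of that combination step in isolation, on dummy predictions:

```py
import torch

def guided_prediction(noise_pred_cond, noise_pred_uncond, true_cfg_scale):
    # Standard CFG combination ...
    comb = noise_pred_uncond + true_cfg_scale * (noise_pred_cond - noise_pred_uncond)
    # ... rescaled so the last-dim norm matches the conditional prediction.
    cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
    comb_norm = torch.norm(comb, dim=-1, keepdim=True)
    return comb * (cond_norm / comb_norm)

cond = torch.randn(1, 16, 64)     # packed latent tokens, (batch, seq, channels*4)
uncond = torch.randn(1, 16, 64)
out = guided_prediction(cond, uncond, true_cfg_scale=4.0)

# The guided direction changes, but the token-wise magnitude stays that of `cond`.
print(torch.allclose(out.norm(dim=-1), cond.norm(dim=-1), atol=1e-5))  # True
```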
videox_fun/pipeline/pipeline_qwenimage_edit_plus.py ADDED
@@ -0,0 +1,937 @@
1
+ # Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
2
+ # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import math
22
+ import PIL.Image
23
+ import torch
24
+ from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
28
+ replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+
31
+ from ..models import (AutoencoderKLQwenImage,
32
+ Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor,
33
+ Qwen2Tokenizer, QwenImageTransformer2DModel)
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from PIL import Image
49
+ >>> from diffusers import QwenImageEditPlusPipeline
50
+
51
+ >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
52
+ >>> pipe.to("cuda")
53
+ >>> prompt = "Change the cat to a dog"
54
+ >>> image = Image.open("cat.png")
55
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
56
+ >>> # Refer to the pipeline documentation for more details.
57
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
58
+ >>> image.save("qwenimageedit.png")
59
+ ```
60
+ """
61
+
62
+ CONDITION_IMAGE_SIZE = 384 * 384
63
+ VAE_IMAGE_SIZE = 1024 * 1024
64
+
65
+ PREFERRED_QWENIMAGE_RESOLUTIONS = [
66
+ (672, 1568),
67
+ (688, 1504),
68
+ (720, 1456),
69
+ (752, 1392),
70
+ (800, 1328),
71
+ (832, 1248),
72
+ (880, 1184),
73
+ (944, 1104),
74
+ (1024, 1024),
75
+ (1104, 944),
76
+ (1184, 880),
77
+ (1248, 832),
78
+ (1328, 800),
79
+ (1392, 752),
80
+ (1456, 720),
81
+ (1504, 688),
82
+ (1568, 672),
83
+ ]
84
+
85
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
86
+ def calculate_shift(
87
+ image_seq_len,
88
+ base_seq_len: int = 256,
89
+ max_seq_len: int = 4096,
90
+ base_shift: float = 0.5,
91
+ max_shift: float = 1.15,
92
+ ):
93
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
94
+ b = base_shift - m * base_seq_len
95
+ mu = image_seq_len * m + b
96
+ return mu
97
+
98
+
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ r"""
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
160
+ def retrieve_latents(
161
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
162
+ ):
163
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
164
+ return encoder_output.latent_dist.sample(generator)
165
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
166
+ return encoder_output.latent_dist.mode()
167
+ elif hasattr(encoder_output, "latents"):
168
+ return encoder_output.latents
169
+ else:
170
+ raise AttributeError("Could not access latents of provided encoder_output")
171
+
172
+
173
+ def calculate_dimensions(target_area, ratio):
174
+ width = math.sqrt(target_area * ratio)
175
+ height = width / ratio
176
+
177
+ width = round(width / 32) * 32
178
+ height = round(height / 32) * 32
179
+
180
+ return width, height
181
+
182
+
183
+ @dataclass
184
+ class QwenImagePipelineOutput(BaseOutput):
185
+ """
186
+ Output class for Stable Diffusion pipelines.
187
+
188
+ Args:
189
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
190
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
191
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
192
+ """
193
+
194
+ images: Union[List[PIL.Image.Image], np.ndarray]
195
+
196
+
197
+ class QwenImageEditPlusPipeline(DiffusionPipeline):
198
+ r"""
199
+ The QwenImage Edit Plus pipeline for image editing with one or more reference images.
200
+
201
+ Args:
202
+ transformer ([`QwenImageTransformer2DModel`]):
203
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
204
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
205
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
206
+ vae ([`AutoencoderKL`]):
207
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
208
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
209
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
210
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
211
+ tokenizer (`Qwen2Tokenizer`):
212
+ Tokenizer of class
213
+ [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
214
+ """
215
+
216
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
217
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
218
+
219
+ def __init__(
220
+ self,
221
+ scheduler: FlowMatchEulerDiscreteScheduler,
222
+ vae: AutoencoderKLQwenImage,
223
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
224
+ tokenizer: Qwen2Tokenizer,
225
+ processor: Qwen2VLProcessor,
226
+ transformer: QwenImageTransformer2DModel,
227
+ ):
228
+ super().__init__()
229
+
230
+ self.register_modules(
231
+ vae=vae,
232
+ text_encoder=text_encoder,
233
+ tokenizer=tokenizer,
234
+ processor=processor,
235
+ transformer=transformer,
236
+ scheduler=scheduler,
237
+ )
238
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
239
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
240
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
241
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
242
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
243
+ self.tokenizer_max_length = 1024
244
+
245
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
246
+ self.prompt_template_encode_start_idx = 64
247
+ self.default_sample_size = 128
248
+
249
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
250
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
251
+ bool_mask = mask.bool()
252
+ valid_lengths = bool_mask.sum(dim=1)
253
+ selected = hidden_states[bool_mask]
254
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
255
+
256
+ return split_result
257
+
258
+ def _get_qwen_prompt_embeds(
259
+ self,
260
+ prompt: Union[str, List[str]] = None,
261
+ image: Optional[torch.Tensor] = None,
262
+ device: Optional[torch.device] = None,
263
+ dtype: Optional[torch.dtype] = None,
264
+ ):
265
+ device = device or self._execution_device
266
+ dtype = dtype or self.text_encoder.dtype
267
+
268
+ prompt = [prompt] if isinstance(prompt, str) else prompt
269
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
270
+ if isinstance(image, list):
271
+ base_img_prompt = ""
272
+ for i, img in enumerate(image):
273
+ base_img_prompt += img_prompt_template.format(i + 1)
274
+ elif image is not None:
275
+ base_img_prompt = img_prompt_template.format(1)
276
+ else:
277
+ base_img_prompt = ""
278
+
279
+ template = self.prompt_template_encode
280
+
281
+ drop_idx = self.prompt_template_encode_start_idx
282
+ txt = [template.format(base_img_prompt + e) for e in prompt]
283
+
284
+ model_inputs = self.processor(
285
+ text=txt,
286
+ images=image,
287
+ padding=True,
288
+ return_tensors="pt",
289
+ ).to(device)
290
+
291
+ outputs = self.text_encoder(
292
+ input_ids=model_inputs.input_ids,
293
+ attention_mask=model_inputs.attention_mask,
294
+ pixel_values=model_inputs.pixel_values,
295
+ image_grid_thw=model_inputs.image_grid_thw,
296
+ output_hidden_states=True,
297
+ )
298
+
299
+ hidden_states = outputs.hidden_states[-1]
300
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
301
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
302
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
303
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
304
+ prompt_embeds = torch.stack(
305
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
306
+ )
307
+ encoder_attention_mask = torch.stack(
308
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
309
+ )
310
+
311
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
312
+
313
+ return prompt_embeds, encoder_attention_mask
314
+
315
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
316
+ def encode_prompt(
317
+ self,
318
+ prompt: Union[str, List[str]],
319
+ image: Optional[torch.Tensor] = None,
320
+ device: Optional[torch.device] = None,
321
+ num_images_per_prompt: int = 1,
322
+ prompt_embeds: Optional[torch.Tensor] = None,
323
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
324
+ max_sequence_length: int = 1024,
325
+ ):
326
+ r"""
327
+
328
+ Args:
329
+ prompt (`str` or `List[str]`, *optional*):
330
+ prompt to be encoded
331
+ image (`torch.Tensor`, *optional*):
332
+ image to be encoded
333
+ device: (`torch.device`):
334
+ torch device
335
+ num_images_per_prompt (`int`):
336
+ number of images that should be generated per prompt
337
+ prompt_embeds (`torch.Tensor`, *optional*):
338
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
339
+ provided, text embeddings will be generated from `prompt` input argument.
340
+ """
341
+ device = device or self._execution_device
342
+
343
+ prompt = [prompt] if isinstance(prompt, str) else prompt
344
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
345
+
346
+ if prompt_embeds is None:
347
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
348
+
349
+ _, seq_len, _ = prompt_embeds.shape
350
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
351
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
352
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
353
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
354
+
355
+ return prompt_embeds, prompt_embeds_mask
356
+
357
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
358
+ def check_inputs(
359
+ self,
360
+ prompt,
361
+ height,
362
+ width,
363
+ negative_prompt=None,
364
+ prompt_embeds=None,
365
+ negative_prompt_embeds=None,
366
+ prompt_embeds_mask=None,
367
+ negative_prompt_embeds_mask=None,
368
+ callback_on_step_end_tensor_inputs=None,
369
+ max_sequence_length=None,
370
+ ):
371
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
372
+ logger.warning(
373
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
374
+ )
375
+
376
+ if callback_on_step_end_tensor_inputs is not None and not all(
377
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
378
+ ):
379
+ raise ValueError(
380
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
381
+ )
382
+
383
+ if prompt is not None and prompt_embeds is not None:
384
+ raise ValueError(
385
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
386
+ " only forward one of the two."
387
+ )
388
+ elif prompt is None and prompt_embeds is None:
389
+ raise ValueError(
390
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
391
+ )
392
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
393
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
394
+
395
+ if negative_prompt is not None and negative_prompt_embeds is not None:
396
+ raise ValueError(
397
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
398
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
399
+ )
400
+
401
+ if prompt_embeds is not None and prompt_embeds_mask is None:
402
+ raise ValueError(
403
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
404
+ )
405
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
406
+ raise ValueError(
407
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
408
+ )
409
+
410
+ if max_sequence_length is not None and max_sequence_length > 1024:
411
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
412
+
413
+ @staticmethod
414
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
415
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
416
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
417
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
418
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
419
+
420
+ return latents
421
+
422
+ @staticmethod
423
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
424
+ def _unpack_latents(latents, height, width, vae_scale_factor):
425
+ batch_size, num_patches, channels = latents.shape
426
+
427
+ # VAE applies 8x compression on images but we must also account for packing which requires
428
+ # latent height and width to be divisible by 2.
429
+ height = 2 * (int(height) // (vae_scale_factor * 2))
430
+ width = 2 * (int(width) // (vae_scale_factor * 2))
431
+
432
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
433
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
434
+
435
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
436
+
437
+ return latents
438
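A minimal, standalone sketch of the 2x2 patch packing that `_pack_latents` / `_unpack_latents` implement, assuming a 16-channel 64x64 latent; `pack` and `unpack` below are illustrative re-implementations, not the pipeline's own methods.

```python
import torch

def pack(latents):
    # (B, C, H, W) latents -> (B, H//2 * W//2, C*4): a sequence of 2x2 patches.
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(packed, h, w):
    # Inverse of pack(); like _unpack_latents it re-inserts a singleton frame axis.
    b, n, ch = packed.shape
    x = packed.view(b, h // 2, w // 2, ch // 4, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, ch // 4, 1, h, w)

z = torch.randn(1, 16, 64, 64)
assert torch.equal(unpack(pack(z), 64, 64).squeeze(2), z)  # exact round trip
```

The round trip is exact because only views, permutes, and reshapes are involved; this is also why `height` and `width` must be divisible by twice the VAE scale factor.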
+
439
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
440
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
441
+ if isinstance(generator, list):
442
+ image_latents = [
443
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
444
+ for i in range(image.shape[0])
445
+ ]
446
+ image_latents = torch.cat(image_latents, dim=0)
447
+ else:
448
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
449
+ latents_mean = (
450
+ torch.tensor(self.vae.config.latents_mean)
451
+ .view(1, self.latent_channels, 1, 1, 1)
452
+ .to(image_latents.device, image_latents.dtype)
453
+ )
454
+ latents_std = (
455
+ torch.tensor(self.vae.config.latents_std)
456
+ .view(1, self.latent_channels, 1, 1, 1)
457
+ .to(image_latents.device, image_latents.dtype)
458
+ )
459
+ image_latents = (image_latents - latents_mean) / latents_std
460
+
461
+ return image_latents
462
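`_encode_vae_image` standardizes the encoded latents channel-wise with the statistics stored in the VAE config; the effect is just `(z - mean) / std`, as in this hedged sketch with made-up per-channel statistics.

```python
import torch

z = torch.randn(1, 16, 1, 60, 60)                            # dummy image latents (B, C, T, H, W)
latents_mean = torch.randn(16).view(1, 16, 1, 1, 1)          # stand-in for vae.config.latents_mean
latents_std = torch.rand(16).add(0.5).view(1, 16, 1, 1, 1)   # stand-in for vae.config.latents_std

z_normed = (z - latents_mean) / latents_std
# The decode path at the end of __call__ undoes this with z * std + mean before vae.decode().
print(z_normed.shape)                                        # torch.Size([1, 16, 1, 60, 60])
```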
+
463
+ def prepare_latents(
464
+ self,
465
+ images,
466
+ batch_size,
467
+ num_channels_latents,
468
+ height,
469
+ width,
470
+ dtype,
471
+ device,
472
+ generator,
473
+ latents=None,
474
+ ):
475
+ # VAE applies 8x compression on images but we must also account for packing which requires
476
+ # latent height and width to be divisible by 2.
477
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
478
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
479
+
480
+ shape = (batch_size, 1, num_channels_latents, height, width)
481
+
482
+ image_latents = None
483
+ if images is not None:
484
+ if not isinstance(images, list):
485
+ images = [images]
486
+ all_image_latents = []
487
+ for image in images:
488
+ image = image.to(device=device, dtype=dtype)
489
+ if image.shape[1] != self.latent_channels:
490
+ image_latents = self._encode_vae_image(image=image, generator=generator)
491
+ else:
492
+ image_latents = image
493
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
494
+ # expand init_latents for batch_size
495
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
496
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
497
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
498
+ raise ValueError(
499
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
500
+ )
501
+ else:
502
+ image_latents = torch.cat([image_latents], dim=0)
503
+
504
+ image_latent_height, image_latent_width = image_latents.shape[3:]
505
+ image_latents = self._pack_latents(
506
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
507
+ )
508
+ all_image_latents.append(image_latents)
509
+ image_latents = torch.cat(all_image_latents, dim=1)
510
+
511
+ if isinstance(generator, list) and len(generator) != batch_size:
512
+ raise ValueError(
513
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515
+ )
516
+ if latents is None:
517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519
+ else:
520
+ latents = latents.to(device=device, dtype=dtype)
521
+
522
+ return latents, image_latents
523
+
524
+ @property
525
+ def guidance_scale(self):
526
+ return self._guidance_scale
527
+
528
+ @property
529
+ def attention_kwargs(self):
530
+ return self._attention_kwargs
531
+
532
+ @property
533
+ def num_timesteps(self):
534
+ return self._num_timesteps
535
+
536
+ @property
537
+ def current_timestep(self):
538
+ return self._current_timestep
539
+
540
+ @property
541
+ def interrupt(self):
542
+ return self._interrupt
543
+
544
+ @torch.no_grad()
545
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
546
+ def __call__(
547
+ self,
548
+ image = None,
549
+ prompt: Union[str, List[str]] = None,
550
+ negative_prompt: Union[str, List[str]] = None,
551
+ true_cfg_scale: float = 4.0,
552
+ height: Optional[int] = None,
553
+ width: Optional[int] = None,
554
+ num_inference_steps: int = 50,
555
+ sigmas: Optional[List[float]] = None,
556
+ guidance_scale: Optional[float] = None,
557
+ num_images_per_prompt: int = 1,
558
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
559
+ latents: Optional[torch.Tensor] = None,
560
+ prompt_embeds: Optional[torch.Tensor] = None,
561
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
562
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
563
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
564
+ output_type: Optional[str] = "pil",
565
+ return_dict: bool = True,
566
+ attention_kwargs: Optional[Dict[str, Any]] = None,
567
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
568
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
569
+ max_sequence_length: int = 512,
570
+ comfyui_progressbar: bool = False,
571
+ ):
572
+ r"""
573
+ Function invoked when calling the pipeline for generation.
574
+
575
+ Args:
576
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
577
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
578
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
579
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
580
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
581
+ latents as `image`, but if passing latents directly it is not encoded again.
582
+ prompt (`str` or `List[str]`, *optional*):
583
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
584
+ instead.
585
+ negative_prompt (`str` or `List[str]`, *optional*):
586
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
587
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
588
+ not greater than `1`).
589
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
590
+ Guidance scale as defined in [Classifier-Free
591
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
592
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
593
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
594
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
595
+ lower image quality.
596
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
597
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
598
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
599
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
600
+ num_inference_steps (`int`, *optional*, defaults to 50):
601
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
602
+ expense of slower inference.
603
+ sigmas (`List[float]`, *optional*):
604
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
605
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
606
+ will be used.
607
+ guidance_scale (`float`, *optional*, defaults to None):
608
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
609
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
610
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
611
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
612
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
613
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
614
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
615
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
616
+ enable classifier-free guidance computations).
617
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
618
+ The number of images to generate per prompt.
619
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
620
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
621
+ to make generation deterministic.
622
+ latents (`torch.Tensor`, *optional*):
623
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
624
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
625
+ tensor will be generated by sampling using the supplied random `generator`.
626
+ prompt_embeds (`torch.Tensor`, *optional*):
627
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
628
+ provided, text embeddings will be generated from `prompt` input argument.
629
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
630
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
631
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
632
+ argument.
633
+ output_type (`str`, *optional*, defaults to `"pil"`):
634
+ The output format of the generate image. Choose between
635
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
636
+ return_dict (`bool`, *optional*, defaults to `True`):
637
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
638
+ attention_kwargs (`dict`, *optional*):
639
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
640
+ `self.processor` in
641
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
642
+ callback_on_step_end (`Callable`, *optional*):
643
+ A function that calls at the end of each denoising steps during the inference. The function is called
644
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
645
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
646
+ `callback_on_step_end_tensor_inputs`.
647
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
648
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
649
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
650
+ `._callback_tensor_inputs` attribute of your pipeline class.
651
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
652
+
653
+ Examples:
654
+
655
+ Returns:
656
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
657
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
658
+ returning a tuple, the first element is a list with the generated images.
659
+ """
660
+ image_size = image[-1].size if isinstance(image, list) else image.size
661
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
662
+ height = height or calculated_height
663
+ width = width or calculated_width
664
+
665
+ multiple_of = self.vae_scale_factor * 2
666
+ width = width // multiple_of * multiple_of
667
+ height = height // multiple_of * multiple_of
668
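`calculate_dimensions` is imported from elsewhere in the repo; a hypothetical stand-in that matches how it is used here (target pixel area plus aspect ratio, followed by the multiple-of rounding above) might look like the sketch below. The exact rounding of the real helper may differ.

```python
import math

def calculate_dimensions_sketch(target_area, aspect_ratio, multiple_of=16):
    # Hypothetical stand-in: pick (width, height) with roughly `target_area` pixels
    # at the given aspect ratio, then snap both down to a multiple of `multiple_of`
    # (16 here matches vae_scale_factor * 2 for this pipeline).
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    width = int(width) // multiple_of * multiple_of
    height = int(height) // multiple_of * multiple_of
    return width, height

print(calculate_dimensions_sketch(1024 * 1024, 4 / 3))  # (1168, 880)
```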
+
669
+ # 1. Check inputs. Raise error if not correct
670
+ self.check_inputs(
671
+ prompt,
672
+ height,
673
+ width,
674
+ negative_prompt=negative_prompt,
675
+ prompt_embeds=prompt_embeds,
676
+ negative_prompt_embeds=negative_prompt_embeds,
677
+ prompt_embeds_mask=prompt_embeds_mask,
678
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
679
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
680
+ max_sequence_length=max_sequence_length,
681
+ )
682
+
683
+ self._guidance_scale = guidance_scale
684
+ self._attention_kwargs = attention_kwargs
685
+ self._current_timestep = None
686
+ self._interrupt = False
687
+
688
+ # 2. Define call parameters
689
+ if prompt is not None and isinstance(prompt, str):
690
+ batch_size = 1
691
+ elif prompt is not None and isinstance(prompt, list):
692
+ batch_size = len(prompt)
693
+ else:
694
+ batch_size = prompt_embeds.shape[0]
695
+
696
+ device = self._execution_device
697
+ if comfyui_progressbar:
698
+ from comfy.utils import ProgressBar
699
+ pbar = ProgressBar(num_inference_steps + 2)
700
+
701
+ # 3. Preprocess image
702
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
703
+ if not isinstance(image, list):
704
+ image = [image]
705
+ condition_image_sizes = []
706
+ condition_images = []
707
+ vae_image_sizes = []
708
+ vae_images = []
709
+ for img in image:
710
+ image_width, image_height = img.size
711
+ condition_width, condition_height = calculate_dimensions(
712
+ CONDITION_IMAGE_SIZE, image_width / image_height
713
+ )
714
+ vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
715
+ condition_image_sizes.append((condition_width, condition_height))
716
+ vae_image_sizes.append((vae_width, vae_height))
717
+ condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
718
+ vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
719
+
720
+ has_neg_prompt = negative_prompt is not None or (
721
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
722
+ )
723
+
724
+ if true_cfg_scale > 1 and not has_neg_prompt:
725
+ logger.warning(
726
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
727
+ )
728
+ elif true_cfg_scale <= 1 and has_neg_prompt:
729
+ logger.warning(
730
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
731
+ )
732
+
733
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
734
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
735
+ image=condition_images,
736
+ prompt=prompt,
737
+ prompt_embeds=prompt_embeds,
738
+ prompt_embeds_mask=prompt_embeds_mask,
739
+ device=device,
740
+ num_images_per_prompt=num_images_per_prompt,
741
+ max_sequence_length=max_sequence_length,
742
+ )
743
+ if do_true_cfg:
744
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
745
+ image=condition_images,
746
+ prompt=negative_prompt,
747
+ prompt_embeds=negative_prompt_embeds,
748
+ prompt_embeds_mask=negative_prompt_embeds_mask,
749
+ device=device,
750
+ num_images_per_prompt=num_images_per_prompt,
751
+ max_sequence_length=max_sequence_length,
752
+ )
753
+ if comfyui_progressbar:
754
+ pbar.update(1)
755
+
756
+ # 4. Prepare latent variables
757
+ num_channels_latents = self.transformer.config.in_channels // 4
758
+ latents, image_latents = self.prepare_latents(
759
+ vae_images,
760
+ batch_size * num_images_per_prompt,
761
+ num_channels_latents,
762
+ height,
763
+ width,
764
+ prompt_embeds.dtype,
765
+ device,
766
+ generator,
767
+ latents,
768
+ )
769
+ img_shapes = [
770
+ [
771
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
772
+ *[
773
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
774
+ for vae_width, vae_height in vae_image_sizes
775
+ ],
776
+ ]
777
+ ] * batch_size
778
+
779
+ # 5. Prepare timesteps
780
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
781
+ image_seq_len = latents.shape[1]
782
+ mu = calculate_shift(
783
+ image_seq_len,
784
+ self.scheduler.config.get("base_image_seq_len", 256),
785
+ self.scheduler.config.get("max_image_seq_len", 4096),
786
+ self.scheduler.config.get("base_shift", 0.5),
787
+ self.scheduler.config.get("max_shift", 1.15),
788
+ )
789
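`calculate_shift` is not defined in this file; in the upstream diffusers Flux/QwenImage pipelines it linearly interpolates the scheduler's shift parameter `mu` between `base_shift` and `max_shift` according to the latent token count. A sketch of that interpolation, under the assumption that this repo's helper behaves the same way:

```python
def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096,
                           base_shift=0.5, max_shift=1.15):
    # Linear interpolation: short sequences get ~base_shift, long ones ~max_shift.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept

print(calculate_shift_sketch(1024))  # ~0.63 for a 1024-token latent sequence
```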
+ timesteps, num_inference_steps = retrieve_timesteps(
790
+ self.scheduler,
791
+ num_inference_steps,
792
+ device,
793
+ sigmas=sigmas,
794
+ mu=mu,
795
+ )
796
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
797
+ self._num_timesteps = len(timesteps)
798
+
799
+ # handle guidance
800
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
801
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
802
+ elif self.transformer.config.guidance_embeds:
803
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
804
+ guidance = guidance.expand(latents.shape[0])
805
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
806
+ logger.warning(
807
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
808
+ )
809
+ guidance = None
810
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
811
+ guidance = None
812
+
813
+ if self.attention_kwargs is None:
814
+ self._attention_kwargs = {}
815
+
816
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
817
+ negative_txt_seq_lens = (
818
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
819
+ )
820
+ if comfyui_progressbar:
821
+ pbar.update(1)
822
+
823
+ # 6. Denoising loop
824
+ self.scheduler.set_begin_index(0)
825
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
826
+ for i, t in enumerate(timesteps):
827
+ self.transformer.current_steps = i
828
+ if self.interrupt:
829
+ continue
830
+
831
+ if image_latents is not None:
832
+ latents_and_image_latents = torch.cat([latents, image_latents], dim=1)
833
+ else:
834
+ latents_and_image_latents = latents
835
+
836
+ if do_true_cfg:
837
+ latent_model_input = torch.cat([latents_and_image_latents] * 2)
838
+ prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask]
839
+ prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds]
840
+ img_shapes_input = img_shapes * 2
841
+ txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
842
+ else:
843
+ latent_model_input = latents_and_image_latents
844
+ prompt_embeds_mask_input = prompt_embeds_mask
845
+ prompt_embeds_input = prompt_embeds
846
+ img_shapes_input = img_shapes
847
+ txt_seq_lens_input = txt_seq_lens
848
+
849
+ if hasattr(self.scheduler, "scale_model_input"):
850
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
851
+
852
+ # handle guidance
853
+ if self.transformer.config.guidance_embeds:
854
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
855
+ guidance = guidance.expand(latent_model_input.shape[0])
856
+ else:
857
+ guidance = None
858
+
859
+ self._current_timestep = t
860
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
861
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
862
+
863
+ with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
864
+ noise_pred = self.transformer.forward_bs(
865
+ x=latent_model_input,
866
+ timestep=timestep / 1000,
867
+ guidance=guidance,
868
+ encoder_hidden_states_mask=prompt_embeds_mask_input,
869
+ encoder_hidden_states=prompt_embeds_input,
870
+ img_shapes=img_shapes_input,
871
+ txt_seq_lens=txt_seq_lens_input,
872
+ attention_kwargs=self.attention_kwargs,
873
+ return_dict=False,
874
+ )
875
+ noise_pred = noise_pred[:, : latents.size(1)]
876
+
877
+ if do_true_cfg:
878
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
879
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
880
+
881
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
882
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
883
+ noise_pred = comb_pred * (cond_norm / noise_norm)
884
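The three lines above implement a norm-preserving variant of classifier-free guidance: the guided prediction is rescaled so its per-token norm matches the conditional prediction. Standalone, on dummy tensors:

```python
import torch

cond = torch.randn(1, 4096, 64)      # dummy conditional prediction
uncond = torch.randn(1, 4096, 64)    # dummy unconditional prediction
scale = 4.0                          # plays the role of true_cfg_scale

guided = uncond + scale * (cond - uncond)
guided = guided * (cond.norm(dim=-1, keepdim=True) / guided.norm(dim=-1, keepdim=True))

# Per-token norms now match the conditional branch again.
print(torch.allclose(guided.norm(dim=-1), cond.norm(dim=-1), atol=1e-4))  # True
```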
+
885
+ # compute the previous noisy sample x_t -> x_t-1
886
+ latents_dtype = latents.dtype
887
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
888
+
889
+ if latents.dtype != latents_dtype:
890
+ if torch.backends.mps.is_available():
891
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
892
+ latents = latents.to(latents_dtype)
893
+
894
+ if callback_on_step_end is not None:
895
+ callback_kwargs = {}
896
+ for k in callback_on_step_end_tensor_inputs:
897
+ callback_kwargs[k] = locals()[k]
898
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
899
+
900
+ latents = callback_outputs.pop("latents", latents)
901
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
902
+
903
+ # call the callback, if provided
904
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
905
+ progress_bar.update()
906
+
907
+ if XLA_AVAILABLE:
908
+ xm.mark_step()
909
+
910
+ if comfyui_progressbar:
911
+ pbar.update(1)
912
+
913
+ self._current_timestep = None
914
+ if output_type == "latent":
915
+ image = latents
916
+ else:
917
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
918
+ latents = latents.to(self.vae.dtype)
919
+ latents_mean = (
920
+ torch.tensor(self.vae.config.latents_mean)
921
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
922
+ .to(latents.device, latents.dtype)
923
+ )
924
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
925
+ latents.device, latents.dtype
926
+ )
927
+ latents = latents / latents_std + latents_mean
928
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
929
+ image = self.image_processor.postprocess(image, output_type=output_type)
930
+
931
+ # Offload all models
932
+ self.maybe_free_model_hooks()
933
+
934
+ if not return_dict:
935
+ return (image,)
936
+
937
+ return QwenImagePipelineOutput(images=image)
videox_fun/pipeline/pipeline_wan.py ADDED
@@ -0,0 +1,576 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
18
+ get_sampling_sigmas)
19
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```python
27
+ pass
28
+ ```
29
+ """
30
+
31
+
32
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
33
+ def retrieve_timesteps(
34
+ scheduler,
35
+ num_inference_steps: Optional[int] = None,
36
+ device: Optional[Union[str, torch.device]] = None,
37
+ timesteps: Optional[List[int]] = None,
38
+ sigmas: Optional[List[float]] = None,
39
+ **kwargs,
40
+ ):
41
+ """
42
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
43
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
44
+
45
+ Args:
46
+ scheduler (`SchedulerMixin`):
47
+ The scheduler to get timesteps from.
48
+ num_inference_steps (`int`):
49
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
50
+ must be `None`.
51
+ device (`str` or `torch.device`, *optional*):
52
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
53
+ timesteps (`List[int]`, *optional*):
54
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
55
+ `num_inference_steps` and `sigmas` must be `None`.
56
+ sigmas (`List[float]`, *optional*):
57
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
58
+ `num_inference_steps` and `timesteps` must be `None`.
59
+
60
+ Returns:
61
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
62
+ second element is the number of inference steps.
63
+ """
64
+ if timesteps is not None and sigmas is not None:
65
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
66
+ if timesteps is not None:
67
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
68
+ if not accepts_timesteps:
69
+ raise ValueError(
70
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
71
+ f" timestep schedules. Please check whether you are using the correct scheduler."
72
+ )
73
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
74
+ timesteps = scheduler.timesteps
75
+ num_inference_steps = len(timesteps)
76
+ elif sigmas is not None:
77
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accept_sigmas:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ else:
87
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ return timesteps, num_inference_steps
90
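A quick usage sketch of `retrieve_timesteps` as defined above; the scheduler and its default config are purely illustrative, and any diffusers scheduler whose `set_timesteps` accepts the given arguments works the same way.

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()  # default config, for illustration only

# Plain call: the scheduler decides the spacing for 30 steps.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_steps, timesteps[:3])
```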
+
91
+
92
+ @dataclass
93
+ class WanPipelineOutput(BaseOutput):
94
+ r"""
95
+ Output class for Wan pipelines.
96
+
97
+ Args:
98
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
99
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
100
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
101
+ `(batch_size, num_frames, channels, height, width)`.
102
+ """
103
+
104
+ videos: torch.Tensor
105
+
106
+
107
+ class WanPipeline(DiffusionPipeline):
108
+ r"""
109
+ Pipeline for text-to-video generation using Wan.
110
+
111
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
112
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
113
+ """
114
+
115
+ _optional_components = []
116
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
117
+
118
+ _callback_tensor_inputs = [
119
+ "latents",
120
+ "prompt_embeds",
121
+ "negative_prompt_embeds",
122
+ ]
123
+
124
+ def __init__(
125
+ self,
126
+ tokenizer: AutoTokenizer,
127
+ text_encoder: WanT5EncoderModel,
128
+ vae: AutoencoderKLWan,
129
+ transformer: WanTransformer3DModel,
130
+ scheduler: FlowMatchEulerDiscreteScheduler,
131
+ ):
132
+ super().__init__()
133
+
134
+ self.register_modules(
135
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
136
+ )
137
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
138
+
139
+ def _get_t5_prompt_embeds(
140
+ self,
141
+ prompt: Union[str, List[str]] = None,
142
+ num_videos_per_prompt: int = 1,
143
+ max_sequence_length: int = 512,
144
+ device: Optional[torch.device] = None,
145
+ dtype: Optional[torch.dtype] = None,
146
+ ):
147
+ device = device or self._execution_device
148
+ dtype = dtype or self.text_encoder.dtype
149
+
150
+ prompt = [prompt] if isinstance(prompt, str) else prompt
151
+ batch_size = len(prompt)
152
+
153
+ text_inputs = self.tokenizer(
154
+ prompt,
155
+ padding="max_length",
156
+ max_length=max_sequence_length,
157
+ truncation=True,
158
+ add_special_tokens=True,
159
+ return_tensors="pt",
160
+ )
161
+ text_input_ids = text_inputs.input_ids
162
+ prompt_attention_mask = text_inputs.attention_mask
163
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
164
+
165
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
166
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
167
+ logger.warning(
168
+ "The following part of your input was truncated because `max_sequence_length` is set to "
169
+ f" {max_sequence_length} tokens: {removed_text}"
170
+ )
171
+
172
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
173
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
174
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
175
+
176
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
177
+ _, seq_len, _ = prompt_embeds.shape
178
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
179
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
180
+
181
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
182
+
183
+ def encode_prompt(
184
+ self,
185
+ prompt: Union[str, List[str]],
186
+ negative_prompt: Optional[Union[str, List[str]]] = None,
187
+ do_classifier_free_guidance: bool = True,
188
+ num_videos_per_prompt: int = 1,
189
+ prompt_embeds: Optional[torch.Tensor] = None,
190
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
191
+ max_sequence_length: int = 512,
192
+ device: Optional[torch.device] = None,
193
+ dtype: Optional[torch.dtype] = None,
194
+ ):
195
+ r"""
196
+ Encodes the prompt into text encoder hidden states.
197
+
198
+ Args:
199
+ prompt (`str` or `List[str]`, *optional*):
200
+ prompt to be encoded
201
+ negative_prompt (`str` or `List[str]`, *optional*):
202
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
203
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
204
+ less than `1`).
205
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
206
+ Whether to use classifier free guidance or not.
207
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
208
+ Number of videos that should be generated per prompt.
209
+ prompt_embeds (`torch.Tensor`, *optional*):
210
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
211
+ provided, text embeddings will be generated from `prompt` input argument.
212
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
213
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
214
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
215
+ argument.
216
+ device: (`torch.device`, *optional*):
217
+ torch device
218
+ dtype: (`torch.dtype`, *optional*):
219
+ torch dtype
220
+ """
221
+ device = device or self._execution_device
222
+
223
+ prompt = [prompt] if isinstance(prompt, str) else prompt
224
+ if prompt is not None:
225
+ batch_size = len(prompt)
226
+ else:
227
+ batch_size = prompt_embeds.shape[0]
228
+
229
+ if prompt_embeds is None:
230
+ prompt_embeds = self._get_t5_prompt_embeds(
231
+ prompt=prompt,
232
+ num_videos_per_prompt=num_videos_per_prompt,
233
+ max_sequence_length=max_sequence_length,
234
+ device=device,
235
+ dtype=dtype,
236
+ )
237
+
238
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
239
+ negative_prompt = negative_prompt or ""
240
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
241
+
242
+ if prompt is not None and type(prompt) is not type(negative_prompt):
243
+ raise TypeError(
244
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
245
+ f" {type(prompt)}."
246
+ )
247
+ elif batch_size != len(negative_prompt):
248
+ raise ValueError(
249
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
250
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
251
+ " the batch size of `prompt`."
252
+ )
253
+
254
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
255
+ prompt=negative_prompt,
256
+ num_videos_per_prompt=num_videos_per_prompt,
257
+ max_sequence_length=max_sequence_length,
258
+ device=device,
259
+ dtype=dtype,
260
+ )
261
+
262
+ return prompt_embeds, negative_prompt_embeds
263
+
264
+ def prepare_latents(
265
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
266
+ ):
267
+ if isinstance(generator, list) and len(generator) != batch_size:
268
+ raise ValueError(
269
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
270
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
271
+ )
272
+
273
+ shape = (
274
+ batch_size,
275
+ num_channels_latents,
276
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
277
+ height // self.vae.spatial_compression_ratio,
278
+ width // self.vae.spatial_compression_ratio,
279
+ )
280
+
281
+ if latents is None:
282
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
283
+ else:
284
+ latents = latents.to(device)
285
+
286
+ # scale the initial noise by the standard deviation required by the scheduler
287
+ if hasattr(self.scheduler, "init_noise_sigma"):
288
+ latents = latents * self.scheduler.init_noise_sigma
289
+ return latents
290
+
291
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
292
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
293
+ frames = (frames / 2 + 0.5).clamp(0, 1)
294
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
295
+ frames = frames.cpu().float().numpy()
296
+ return frames
297
+
298
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
299
+ def prepare_extra_step_kwargs(self, generator, eta):
300
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
301
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
302
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
303
+ # and should be between [0, 1]
304
+
305
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
306
+ extra_step_kwargs = {}
307
+ if accepts_eta:
308
+ extra_step_kwargs["eta"] = eta
309
+
310
+ # check if the scheduler accepts generator
311
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
312
+ if accepts_generator:
313
+ extra_step_kwargs["generator"] = generator
314
+ return extra_step_kwargs
315
+
316
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
317
+ def check_inputs(
318
+ self,
319
+ prompt,
320
+ height,
321
+ width,
322
+ negative_prompt,
323
+ callback_on_step_end_tensor_inputs,
324
+ prompt_embeds=None,
325
+ negative_prompt_embeds=None,
326
+ ):
327
+ if height % 8 != 0 or width % 8 != 0:
328
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
329
+
330
+ if callback_on_step_end_tensor_inputs is not None and not all(
331
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
332
+ ):
333
+ raise ValueError(
334
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
335
+ )
336
+ if prompt is not None and prompt_embeds is not None:
337
+ raise ValueError(
338
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
339
+ " only forward one of the two."
340
+ )
341
+ elif prompt is None and prompt_embeds is None:
342
+ raise ValueError(
343
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
344
+ )
345
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
346
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
347
+
348
+ if prompt is not None and negative_prompt_embeds is not None:
349
+ raise ValueError(
350
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
351
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
352
+ )
353
+
354
+ if negative_prompt is not None and negative_prompt_embeds is not None:
355
+ raise ValueError(
356
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
357
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
358
+ )
359
+
360
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
361
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
362
+ raise ValueError(
363
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
364
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
365
+ f" {negative_prompt_embeds.shape}."
366
+ )
367
+
368
+ @property
369
+ def guidance_scale(self):
370
+ return self._guidance_scale
371
+
372
+ @property
373
+ def num_timesteps(self):
374
+ return self._num_timesteps
375
+
376
+ @property
377
+ def attention_kwargs(self):
378
+ return self._attention_kwargs
379
+
380
+ @property
381
+ def interrupt(self):
382
+ return self._interrupt
383
+
384
+ @torch.no_grad()
385
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
386
+ def __call__(
387
+ self,
388
+ prompt: Optional[Union[str, List[str]]] = None,
389
+ negative_prompt: Optional[Union[str, List[str]]] = None,
390
+ height: int = 480,
391
+ width: int = 720,
392
+ num_frames: int = 49,
393
+ num_inference_steps: int = 50,
394
+ timesteps: Optional[List[int]] = None,
395
+ guidance_scale: float = 6,
396
+ num_videos_per_prompt: int = 1,
397
+ eta: float = 0.0,
398
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
399
+ latents: Optional[torch.FloatTensor] = None,
400
+ prompt_embeds: Optional[torch.FloatTensor] = None,
401
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
402
+ output_type: str = "numpy",
403
+ return_dict: bool = False,
404
+ callback_on_step_end: Optional[
405
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
406
+ ] = None,
407
+ attention_kwargs: Optional[Dict[str, Any]] = None,
408
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
409
+ max_sequence_length: int = 512,
410
+ comfyui_progressbar: bool = False,
411
+ shift: int = 5,
412
+ ) -> Union[WanPipelineOutput, Tuple]:
413
+ """
414
+ Function invoked when calling the pipeline for generation.
415
+ Args:
416
+
417
+ Examples:
418
+
419
+ Returns:
420
+
421
+ """
422
+
423
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
424
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
425
+ num_videos_per_prompt = 1
426
+
427
+ # 1. Check inputs. Raise error if not correct
428
+ self.check_inputs(
429
+ prompt,
430
+ height,
431
+ width,
432
+ negative_prompt,
433
+ callback_on_step_end_tensor_inputs,
434
+ prompt_embeds,
435
+ negative_prompt_embeds,
436
+ )
437
+ self._guidance_scale = guidance_scale
438
+ self._attention_kwargs = attention_kwargs
439
+ self._interrupt = False
440
+
441
+ # 2. Default call parameters
442
+ if prompt is not None and isinstance(prompt, str):
443
+ batch_size = 1
444
+ elif prompt is not None and isinstance(prompt, list):
445
+ batch_size = len(prompt)
446
+ else:
447
+ batch_size = prompt_embeds.shape[0]
448
+
449
+ device = self._execution_device
450
+ weight_dtype = self.text_encoder.dtype
451
+
452
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
453
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
454
+ # corresponds to doing no classifier free guidance.
455
+ do_classifier_free_guidance = guidance_scale > 1.0
456
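For reference, the combination applied later in the denoising loop (and described by the comment above) is the usual classifier-free guidance rule with weight $w$ = `guidance_scale`:

```latex
\hat{\epsilon} = \epsilon_{\text{uncond}} + w\,\bigl(\epsilon_{\text{text}} - \epsilon_{\text{uncond}}\bigr)
               = (1 - w)\,\epsilon_{\text{uncond}} + w\,\epsilon_{\text{text}}
```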
+
457
+ # 3. Encode input prompt
458
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
459
+ prompt,
460
+ negative_prompt,
461
+ do_classifier_free_guidance,
462
+ num_videos_per_prompt=num_videos_per_prompt,
463
+ prompt_embeds=prompt_embeds,
464
+ negative_prompt_embeds=negative_prompt_embeds,
465
+ max_sequence_length=max_sequence_length,
466
+ device=device,
467
+ )
468
+ if do_classifier_free_guidance:
469
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
470
+ else:
471
+ in_prompt_embeds = prompt_embeds
472
+
473
+ # 4. Prepare timesteps
474
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
475
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
476
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
477
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
478
+ timesteps = self.scheduler.timesteps
479
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
480
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
481
+ timesteps, _ = retrieve_timesteps(
482
+ self.scheduler,
483
+ device=device,
484
+ sigmas=sampling_sigmas)
485
+ else:
486
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
487
+ self._num_timesteps = len(timesteps)
488
+ if comfyui_progressbar:
489
+ from comfy.utils import ProgressBar
490
+ pbar = ProgressBar(num_inference_steps + 1)
491
+
492
+ # 5. Prepare latents
493
+ latent_channels = self.transformer.config.in_channels
494
+ latents = self.prepare_latents(
495
+ batch_size * num_videos_per_prompt,
496
+ latent_channels,
497
+ num_frames,
498
+ height,
499
+ width,
500
+ weight_dtype,
501
+ device,
502
+ generator,
503
+ latents,
504
+ )
505
+ if comfyui_progressbar:
506
+ pbar.update(1)
507
+
508
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
509
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
510
+
511
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
512
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
513
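A worked example of the `seq_len` computed above, assuming the usual Wan 2.1 settings (spatial compression 8, temporal compression 4, transformer patch size (1, 2, 2)); these values are assumptions about the loaded checkpoint, not read from this file.

```python
import math

height, width, num_frames = 480, 720, 49
spatial, temporal, patch = 8, 4, (1, 2, 2)

lat_t = (num_frames - 1) // temporal + 1             # 13 latent frames
lat_h, lat_w = height // spatial, width // spatial   # 60 x 90 latent pixels

seq_len = math.ceil((lat_h * lat_w) / (patch[1] * patch[2]) * lat_t)
print(lat_t, lat_h, lat_w, seq_len)                  # 13 60 90 17550
```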
+ # 7. Denoising loop
514
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
515
+ self.transformer.num_inference_steps = num_inference_steps
516
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
517
+ for i, t in enumerate(timesteps):
518
+ self.transformer.current_steps = i
519
+
520
+ if self.interrupt:
521
+ continue
522
+
523
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
524
+ if hasattr(self.scheduler, "scale_model_input"):
525
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
526
+
527
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
528
+ timestep = t.expand(latent_model_input.shape[0])
529
+
530
+ # predict noise model_output
531
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
532
+ noise_pred = self.transformer(
533
+ x=latent_model_input,
534
+ context=in_prompt_embeds,
535
+ t=timestep,
536
+ seq_len=seq_len,
537
+ )
538
+
539
+ # perform guidance
540
+ if do_classifier_free_guidance:
541
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
542
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
543
+
544
+ # compute the previous noisy sample x_t -> x_t-1
545
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
546
+
547
+ if callback_on_step_end is not None:
548
+ callback_kwargs = {}
549
+ for k in callback_on_step_end_tensor_inputs:
550
+ callback_kwargs[k] = locals()[k]
551
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
552
+
553
+ latents = callback_outputs.pop("latents", latents)
554
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
555
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
556
+
557
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
558
+ progress_bar.update()
559
+ if comfyui_progressbar:
560
+ pbar.update(1)
561
+
562
+ if output_type == "numpy":
563
+ video = self.decode_latents(latents)
564
+ elif not output_type == "latent":
565
+ video = self.decode_latents(latents)
566
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
567
+ else:
568
+ video = latents
569
+
570
+ # Offload all models
571
+ self.maybe_free_model_hooks()
572
+
573
+ if not return_dict:
574
+ video = torch.from_numpy(video)
575
+
576
+ return WanPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_wan2_2.py ADDED
@@ -0,0 +1,591 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, Wan2_2Transformer3DModel)
17
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
18
+ get_sampling_sigmas)
19
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```python
27
+ pass
28
+ ```
29
+ """
30
+
31
+
32
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
33
+ def retrieve_timesteps(
34
+ scheduler,
35
+ num_inference_steps: Optional[int] = None,
36
+ device: Optional[Union[str, torch.device]] = None,
37
+ timesteps: Optional[List[int]] = None,
38
+ sigmas: Optional[List[float]] = None,
39
+ **kwargs,
40
+ ):
41
+ """
42
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
43
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
44
+
45
+ Args:
46
+ scheduler (`SchedulerMixin`):
47
+ The scheduler to get timesteps from.
48
+ num_inference_steps (`int`):
49
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
50
+ must be `None`.
51
+ device (`str` or `torch.device`, *optional*):
52
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
53
+ timesteps (`List[int]`, *optional*):
54
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
55
+ `num_inference_steps` and `sigmas` must be `None`.
56
+ sigmas (`List[float]`, *optional*):
57
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
58
+ `num_inference_steps` and `timesteps` must be `None`.
59
+
60
+ Returns:
61
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
62
+ second element is the number of inference steps.
63
+ """
64
+ if timesteps is not None and sigmas is not None:
65
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
66
+ if timesteps is not None:
67
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
68
+ if not accepts_timesteps:
69
+ raise ValueError(
70
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
71
+ f" timestep schedules. Please check whether you are using the correct scheduler."
72
+ )
73
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
74
+ timesteps = scheduler.timesteps
75
+ num_inference_steps = len(timesteps)
76
+ elif sigmas is not None:
77
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accept_sigmas:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ else:
87
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ return timesteps, num_inference_steps
90
+
91
+
92
+ @dataclass
93
+ class WanPipelineOutput(BaseOutput):
94
+ r"""
95
+ Output class for Wan pipelines.
96
+
97
+ Args:
98
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
99
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
100
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
101
+ `(batch_size, num_frames, channels, height, width)`.
102
+ """
103
+
104
+ videos: torch.Tensor
105
+
106
+
107
+ class Wan2_2Pipeline(DiffusionPipeline):
108
+ r"""
109
+ Pipeline for text-to-video generation using Wan.
110
+
111
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
112
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
113
+ """
114
+
115
+ _optional_components = ["transformer_2"]
116
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
117
+
118
+ _callback_tensor_inputs = [
119
+ "latents",
120
+ "prompt_embeds",
121
+ "negative_prompt_embeds",
122
+ ]
123
+
124
+ def __init__(
125
+ self,
126
+ tokenizer: AutoTokenizer,
127
+ text_encoder: WanT5EncoderModel,
128
+ vae: AutoencoderKLWan,
129
+ transformer: Wan2_2Transformer3DModel,
130
+ transformer_2: Wan2_2Transformer3DModel = None,
131
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
132
+ ):
133
+ super().__init__()
134
+
135
+ self.register_modules(
136
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
137
+ transformer_2=transformer_2, scheduler=scheduler
138
+ )
139
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
140
+
141
+ def _get_t5_prompt_embeds(
142
+ self,
143
+ prompt: Union[str, List[str]] = None,
144
+ num_videos_per_prompt: int = 1,
145
+ max_sequence_length: int = 512,
146
+ device: Optional[torch.device] = None,
147
+ dtype: Optional[torch.dtype] = None,
148
+ ):
149
+ device = device or self._execution_device
150
+ dtype = dtype or self.text_encoder.dtype
151
+
152
+ prompt = [prompt] if isinstance(prompt, str) else prompt
153
+ batch_size = len(prompt)
154
+
155
+ text_inputs = self.tokenizer(
156
+ prompt,
157
+ padding="max_length",
158
+ max_length=max_sequence_length,
159
+ truncation=True,
160
+ add_special_tokens=True,
161
+ return_tensors="pt",
162
+ )
163
+ text_input_ids = text_inputs.input_ids
164
+ prompt_attention_mask = text_inputs.attention_mask
165
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
166
+
167
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
168
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
169
+ logger.warning(
170
+ "The following part of your input was truncated because `max_sequence_length` is set to "
171
+ f" {max_sequence_length} tokens: {removed_text}"
172
+ )
173
+
174
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
175
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
176
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
177
+
178
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
179
+ _, seq_len, _ = prompt_embeds.shape
180
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
181
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
182
+
183
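+ # Trim padding: return a list of per-prompt embeddings, each cropped to its true token length.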
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
184
+
185
+ def encode_prompt(
186
+ self,
187
+ prompt: Union[str, List[str]],
188
+ negative_prompt: Optional[Union[str, List[str]]] = None,
189
+ do_classifier_free_guidance: bool = True,
190
+ num_videos_per_prompt: int = 1,
191
+ prompt_embeds: Optional[torch.Tensor] = None,
192
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ r"""
198
+ Encodes the prompt into text encoder hidden states.
199
+
200
+ Args:
201
+ prompt (`str` or `List[str]`, *optional*):
202
+ prompt to be encoded
203
+ negative_prompt (`str` or `List[str]`, *optional*):
204
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
205
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
206
+ less than `1`).
207
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
208
+ Whether to use classifier free guidance or not.
209
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
210
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
211
+ prompt_embeds (`torch.Tensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ device: (`torch.device`, *optional*):
219
+ torch device
220
+ dtype: (`torch.dtype`, *optional*):
221
+ torch dtype
222
+ """
223
+ device = device or self._execution_device
224
+
225
+ prompt = [prompt] if isinstance(prompt, str) else prompt
226
+ if prompt is not None:
227
+ batch_size = len(prompt)
228
+ else:
229
+ batch_size = prompt_embeds.shape[0]
230
+
231
+ if prompt_embeds is None:
232
+ prompt_embeds = self._get_t5_prompt_embeds(
233
+ prompt=prompt,
234
+ num_videos_per_prompt=num_videos_per_prompt,
235
+ max_sequence_length=max_sequence_length,
236
+ device=device,
237
+ dtype=dtype,
238
+ )
239
+
240
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
241
+ negative_prompt = negative_prompt or ""
242
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
243
+
244
+ if prompt is not None and type(prompt) is not type(negative_prompt):
245
+ raise TypeError(
246
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
247
+ f" {type(prompt)}."
248
+ )
249
+ elif batch_size != len(negative_prompt):
250
+ raise ValueError(
251
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
252
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
253
+ " the batch size of `prompt`."
254
+ )
255
+
256
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
257
+ prompt=negative_prompt,
258
+ num_videos_per_prompt=num_videos_per_prompt,
259
+ max_sequence_length=max_sequence_length,
260
+ device=device,
261
+ dtype=dtype,
262
+ )
263
+
264
+ return prompt_embeds, negative_prompt_embeds
265
+
266
+ def prepare_latents(
267
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
268
+ ):
269
+ if isinstance(generator, list) and len(generator) != batch_size:
270
+ raise ValueError(
271
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
272
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
273
+ )
274
+
275
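+ # Latent shape: frames are compressed temporally to (num_frames - 1) // temporal_compression_ratio + 1
+ # and spatially by spatial_compression_ratio on each side.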
+ shape = (
276
+ batch_size,
277
+ num_channels_latents,
278
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
279
+ height // self.vae.spatial_compression_ratio,
280
+ width // self.vae.spatial_compression_ratio,
281
+ )
282
+
283
+ if latents is None:
284
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
285
+ else:
286
+ latents = latents.to(device)
287
+
288
+ # scale the initial noise by the standard deviation required by the scheduler
289
+ if hasattr(self.scheduler, "init_noise_sigma"):
290
+ latents = latents * self.scheduler.init_noise_sigma
291
+ return latents
292
+
293
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
294
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
295
+ frames = (frames / 2 + 0.5).clamp(0, 1)
296
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
297
+ frames = frames.cpu().float().numpy()
298
+ return frames
299
+
300
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
301
+ def prepare_extra_step_kwargs(self, generator, eta):
302
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
303
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
304
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
305
+ # and should be between [0, 1]
306
+
307
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
308
+ extra_step_kwargs = {}
309
+ if accepts_eta:
310
+ extra_step_kwargs["eta"] = eta
311
+
312
+ # check if the scheduler accepts generator
313
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
314
+ if accepts_generator:
315
+ extra_step_kwargs["generator"] = generator
316
+ return extra_step_kwargs
317
+
318
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
319
+ def check_inputs(
320
+ self,
321
+ prompt,
322
+ height,
323
+ width,
324
+ negative_prompt,
325
+ callback_on_step_end_tensor_inputs,
326
+ prompt_embeds=None,
327
+ negative_prompt_embeds=None,
328
+ ):
329
+ if height % 8 != 0 or width % 8 != 0:
330
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
331
+
332
+ if callback_on_step_end_tensor_inputs is not None and not all(
333
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
334
+ ):
335
+ raise ValueError(
336
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
337
+ )
338
+ if prompt is not None and prompt_embeds is not None:
339
+ raise ValueError(
340
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
341
+ " only forward one of the two."
342
+ )
343
+ elif prompt is None and prompt_embeds is None:
344
+ raise ValueError(
345
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
346
+ )
347
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
348
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
349
+
350
+ if prompt is not None and negative_prompt_embeds is not None:
351
+ raise ValueError(
352
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
353
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
354
+ )
355
+
356
+ if negative_prompt is not None and negative_prompt_embeds is not None:
357
+ raise ValueError(
358
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
359
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
360
+ )
361
+
362
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
363
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
364
+ raise ValueError(
365
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
366
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
367
+ f" {negative_prompt_embeds.shape}."
368
+ )
369
+
370
+ @property
371
+ def guidance_scale(self):
372
+ return self._guidance_scale
373
+
374
+ @property
375
+ def num_timesteps(self):
376
+ return self._num_timesteps
377
+
378
+ @property
379
+ def attention_kwargs(self):
380
+ return self._attention_kwargs
381
+
382
+ @property
383
+ def interrupt(self):
384
+ return self._interrupt
385
+
386
+ @torch.no_grad()
387
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
388
+ def __call__(
389
+ self,
390
+ prompt: Optional[Union[str, List[str]]] = None,
391
+ negative_prompt: Optional[Union[str, List[str]]] = None,
392
+ height: int = 480,
393
+ width: int = 720,
394
+ num_frames: int = 49,
395
+ num_inference_steps: int = 50,
396
+ timesteps: Optional[List[int]] = None,
397
+ guidance_scale: float = 6,
398
+ num_videos_per_prompt: int = 1,
399
+ eta: float = 0.0,
400
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
401
+ latents: Optional[torch.FloatTensor] = None,
402
+ prompt_embeds: Optional[torch.FloatTensor] = None,
403
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
404
+ output_type: str = "numpy",
405
+ return_dict: bool = False,
406
+ callback_on_step_end: Optional[
407
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
408
+ ] = None,
409
+ attention_kwargs: Optional[Dict[str, Any]] = None,
410
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
411
+ max_sequence_length: int = 512,
412
+ boundary: float = 0.875,
413
+ comfyui_progressbar: bool = False,
414
+ shift: int = 5,
415
+ ) -> Union[WanPipelineOutput, Tuple]:
416
+ """
417
+ Function invoked when calling the pipeline for generation.
418
+ Args:
419
+
420
+ Examples:
421
+
422
+ Returns:
423
+
424
+ """
425
+
426
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
427
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
428
+ num_videos_per_prompt = 1
429
+
430
+ # 1. Check inputs. Raise error if not correct
431
+ self.check_inputs(
432
+ prompt,
433
+ height,
434
+ width,
435
+ negative_prompt,
436
+ callback_on_step_end_tensor_inputs,
437
+ prompt_embeds,
438
+ negative_prompt_embeds,
439
+ )
440
+ self._guidance_scale = guidance_scale
441
+ self._attention_kwargs = attention_kwargs
442
+ self._interrupt = False
443
+
444
+ # 2. Default call parameters
445
+ if prompt is not None and isinstance(prompt, str):
446
+ batch_size = 1
447
+ elif prompt is not None and isinstance(prompt, list):
448
+ batch_size = len(prompt)
449
+ else:
450
+ batch_size = prompt_embeds.shape[0]
451
+
452
+ device = self._execution_device
453
+ weight_dtype = self.text_encoder.dtype
454
+
455
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
456
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
457
+ # corresponds to doing no classifier free guidance.
458
+ do_classifier_free_guidance = guidance_scale > 1.0
459
+
460
+ # 3. Encode input prompt
461
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
462
+ prompt,
463
+ negative_prompt,
464
+ do_classifier_free_guidance,
465
+ num_videos_per_prompt=num_videos_per_prompt,
466
+ prompt_embeds=prompt_embeds,
467
+ negative_prompt_embeds=negative_prompt_embeds,
468
+ max_sequence_length=max_sequence_length,
469
+ device=device,
470
+ )
471
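+ # `_get_t5_prompt_embeds` returns Python lists of per-sample tensors, so `+` here concatenates the
+ # negative embeddings in front of the positive ones, matching the (uncond, text) chunk order used for CFG below.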
+ if do_classifier_free_guidance:
472
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
473
+ else:
474
+ in_prompt_embeds = prompt_embeds
475
+
476
+ # 4. Prepare timesteps
477
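+ # Each supported scheduler has a different timestep API: FlowMatchEuler goes through retrieve_timesteps
+ # with `mu`, FlowUniPC takes `shift` in set_timesteps, and FlowDPMSolver consumes explicit sampling sigmas.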
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
478
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
479
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
480
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
481
+ timesteps = self.scheduler.timesteps
482
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
483
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
484
+ timesteps, _ = retrieve_timesteps(
485
+ self.scheduler,
486
+ device=device,
487
+ sigmas=sampling_sigmas)
488
+ else:
489
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
490
+ self._num_timesteps = len(timesteps)
491
+ if comfyui_progressbar:
492
+ from comfy.utils import ProgressBar
493
+ pbar = ProgressBar(num_inference_steps + 1)
494
+
495
+ # 5. Prepare latents
496
+ latent_channels = self.transformer.config.in_channels
497
+ latents = self.prepare_latents(
498
+ batch_size * num_videos_per_prompt,
499
+ latent_channels,
500
+ num_frames,
501
+ height,
502
+ width,
503
+ weight_dtype,
504
+ device,
505
+ generator,
506
+ latents,
507
+ )
508
+ if comfyui_progressbar:
509
+ pbar.update(1)
510
+
511
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
512
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
513
+
514
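+ # target_shape is (C, T, W, H) in latent space; seq_len is the transformer token count:
+ # spatial patches per latent frame times the number of latent frames.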
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
515
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
516
+ # 7. Denoising loop
517
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
518
+ self.transformer.num_inference_steps = num_inference_steps
519
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
520
+ for i, t in enumerate(timesteps):
521
+ self.transformer.current_steps = i
522
+
523
+ if self.interrupt:
524
+ continue
525
+
526
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
527
+ if hasattr(self.scheduler, "scale_model_input"):
528
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
529
+
530
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
531
+ timestep = t.expand(latent_model_input.shape[0])
532
+
533
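+ # Two-expert denoising: when a second transformer is provided, it handles the high-noise timesteps
+ # (t >= boundary * num_train_timesteps) while `transformer` denoises the remaining low-noise steps.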
+ if self.transformer_2 is not None:
534
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
535
+ local_transformer = self.transformer_2
536
+ else:
537
+ local_transformer = self.transformer
538
+ else:
539
+ local_transformer = self.transformer
540
+
541
+ # predict noise model_output
542
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
543
+ noise_pred = local_transformer(
544
+ x=latent_model_input,
545
+ context=in_prompt_embeds,
546
+ t=timestep,
547
+ seq_len=seq_len,
548
+ )
549
+
550
+ # perform guidance
551
+ if do_classifier_free_guidance:
552
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
553
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
554
+ else:
555
+ sample_guide_scale = self.guidance_scale
556
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
557
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
558
+
559
+ # compute the previous noisy sample x_t -> x_t-1
560
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
561
+
562
+ if callback_on_step_end is not None:
563
+ callback_kwargs = {}
564
+ for k in callback_on_step_end_tensor_inputs:
565
+ callback_kwargs[k] = locals()[k]
566
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
567
+
568
+ latents = callback_outputs.pop("latents", latents)
569
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
570
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
571
+
572
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
573
+ progress_bar.update()
574
+ if comfyui_progressbar:
575
+ pbar.update(1)
576
+
577
+ if output_type == "numpy":
578
+ video = self.decode_latents(latents)
579
+ elif not output_type == "latent":
580
+ video = self.decode_latents(latents)
581
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
582
+ else:
583
+ video = latents
584
+
585
+ # Offload all models
586
+ self.maybe_free_model_hooks()
587
+
588
+ if not return_dict:
589
+ video = torch.from_numpy(video)
590
+
591
+ return WanPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_wan2_2_animate.py ADDED
@@ -0,0 +1,929 @@
1
+ import inspect
2
+ import math
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import copy
9
+ import torch
10
+ import cv2
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from diffusers import FlowMatchEulerDiscreteScheduler
14
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
15
+ from diffusers.image_processor import VaeImageProcessor
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from diffusers.video_processor import VideoProcessor
20
+ from decord import VideoReader
21
+
22
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
23
+ WanT5EncoderModel, Wan2_2Transformer3DModel_Animate)
24
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas)
26
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```python
34
+ pass
35
+ ```
36
+ """
37
+
38
+
39
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
40
+ def retrieve_timesteps(
41
+ scheduler,
42
+ num_inference_steps: Optional[int] = None,
43
+ device: Optional[Union[str, torch.device]] = None,
44
+ timesteps: Optional[List[int]] = None,
45
+ sigmas: Optional[List[float]] = None,
46
+ **kwargs,
47
+ ):
48
+ """
49
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
50
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
51
+
52
+ Args:
53
+ scheduler (`SchedulerMixin`):
54
+ The scheduler to get timesteps from.
55
+ num_inference_steps (`int`):
56
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
57
+ must be `None`.
58
+ device (`str` or `torch.device`, *optional*):
59
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
60
+ timesteps (`List[int]`, *optional*):
61
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
62
+ `num_inference_steps` and `sigmas` must be `None`.
63
+ sigmas (`List[float]`, *optional*):
64
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
65
+ `num_inference_steps` and `timesteps` must be `None`.
66
+
67
+ Returns:
68
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
69
+ second element is the number of inference steps.
70
+ """
71
+ if timesteps is not None and sigmas is not None:
72
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
73
+ if timesteps is not None:
74
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accepts_timesteps:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" timestep schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ elif sigmas is not None:
84
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
85
+ if not accept_sigmas:
86
+ raise ValueError(
87
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
88
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
89
+ )
90
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ num_inference_steps = len(timesteps)
93
+ else:
94
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
95
+ timesteps = scheduler.timesteps
96
+ return timesteps, num_inference_steps
97
+
98
+
99
+ @dataclass
100
+ class WanPipelineOutput(BaseOutput):
101
+ r"""
102
+ Output class for Wan pipelines.
103
+
104
+ Args:
105
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
106
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
107
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
108
+ `(batch_size, num_frames, channels, height, width)`.
109
+ """
110
+
111
+ videos: torch.Tensor
112
+
113
+
114
+ class Wan2_2AnimatePipeline(DiffusionPipeline):
115
+ r"""
116
+ Pipeline for character animation and replacement video generation using Wan2.2 Animate.
117
+
118
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
119
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
120
+ """
121
+
122
+ _optional_components = ["transformer_2", "clip_image_encoder"]
123
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer_2->transformer->vae"
124
+
125
+ _callback_tensor_inputs = [
126
+ "latents",
127
+ "prompt_embeds",
128
+ "negative_prompt_embeds",
129
+ ]
130
+
131
+ def __init__(
132
+ self,
133
+ tokenizer: AutoTokenizer,
134
+ text_encoder: WanT5EncoderModel,
135
+ vae: AutoencoderKLWan,
136
+ transformer: Wan2_2Transformer3DModel_Animate,
137
+ transformer_2: Wan2_2Transformer3DModel_Animate = None,
138
+ clip_image_encoder: CLIPModel = None,
139
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
140
+ ):
141
+ super().__init__()
142
+
143
+ self.register_modules(
144
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
145
+ transformer_2=transformer_2, clip_image_encoder=clip_image_encoder, scheduler=scheduler
146
+ )
147
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
148
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
149
+ self.mask_processor = VaeImageProcessor(
150
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
151
+ )
152
+
153
+ def _get_t5_prompt_embeds(
154
+ self,
155
+ prompt: Union[str, List[str]] = None,
156
+ num_videos_per_prompt: int = 1,
157
+ max_sequence_length: int = 512,
158
+ device: Optional[torch.device] = None,
159
+ dtype: Optional[torch.dtype] = None,
160
+ ):
161
+ device = device or self._execution_device
162
+ dtype = dtype or self.text_encoder.dtype
163
+
164
+ prompt = [prompt] if isinstance(prompt, str) else prompt
165
+ batch_size = len(prompt)
166
+
167
+ text_inputs = self.tokenizer(
168
+ prompt,
169
+ padding="max_length",
170
+ max_length=max_sequence_length,
171
+ truncation=True,
172
+ add_special_tokens=True,
173
+ return_tensors="pt",
174
+ )
175
+ text_input_ids = text_inputs.input_ids
176
+ prompt_attention_mask = text_inputs.attention_mask
177
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
178
+
179
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
180
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
181
+ logger.warning(
182
+ "The following part of your input was truncated because `max_sequence_length` is set to "
183
+ f" {max_sequence_length} tokens: {removed_text}"
184
+ )
185
+
186
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
187
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
188
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
189
+
190
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
191
+ _, seq_len, _ = prompt_embeds.shape
192
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
193
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
194
+
195
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
196
+
197
+ def encode_prompt(
198
+ self,
199
+ prompt: Union[str, List[str]],
200
+ negative_prompt: Optional[Union[str, List[str]]] = None,
201
+ do_classifier_free_guidance: bool = True,
202
+ num_videos_per_prompt: int = 1,
203
+ prompt_embeds: Optional[torch.Tensor] = None,
204
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
205
+ max_sequence_length: int = 512,
206
+ device: Optional[torch.device] = None,
207
+ dtype: Optional[torch.dtype] = None,
208
+ ):
209
+ r"""
210
+ Encodes the prompt into text encoder hidden states.
211
+
212
+ Args:
213
+ prompt (`str` or `List[str]`, *optional*):
214
+ prompt to be encoded
215
+ negative_prompt (`str` or `List[str]`, *optional*):
216
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
217
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
218
+ less than `1`).
219
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
220
+ Whether to use classifier free guidance or not.
221
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
222
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
223
+ prompt_embeds (`torch.Tensor`, *optional*):
224
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
225
+ provided, text embeddings will be generated from `prompt` input argument.
226
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
227
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
228
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
229
+ argument.
230
+ device: (`torch.device`, *optional*):
231
+ torch device
232
+ dtype: (`torch.dtype`, *optional*):
233
+ torch dtype
234
+ """
235
+ device = device or self._execution_device
236
+
237
+ prompt = [prompt] if isinstance(prompt, str) else prompt
238
+ if prompt is not None:
239
+ batch_size = len(prompt)
240
+ else:
241
+ batch_size = prompt_embeds.shape[0]
242
+
243
+ if prompt_embeds is None:
244
+ prompt_embeds = self._get_t5_prompt_embeds(
245
+ prompt=prompt,
246
+ num_videos_per_prompt=num_videos_per_prompt,
247
+ max_sequence_length=max_sequence_length,
248
+ device=device,
249
+ dtype=dtype,
250
+ )
251
+
252
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
253
+ negative_prompt = negative_prompt or ""
254
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
255
+
256
+ if prompt is not None and type(prompt) is not type(negative_prompt):
257
+ raise TypeError(
258
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
259
+ f" {type(prompt)}."
260
+ )
261
+ elif batch_size != len(negative_prompt):
262
+ raise ValueError(
263
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
264
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
265
+ " the batch size of `prompt`."
266
+ )
267
+
268
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
269
+ prompt=negative_prompt,
270
+ num_videos_per_prompt=num_videos_per_prompt,
271
+ max_sequence_length=max_sequence_length,
272
+ device=device,
273
+ dtype=dtype,
274
+ )
275
+
276
+ return prompt_embeds, negative_prompt_embeds
277
+
278
+ def prepare_latents(
279
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
280
+ ):
281
+ if isinstance(generator, list) and len(generator) != batch_size:
282
+ raise ValueError(
283
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
284
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
285
+ )
286
+
287
+ shape = (
288
+ batch_size,
289
+ num_channels_latents,
290
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
291
+ height // self.vae.spatial_compression_ratio,
292
+ width // self.vae.spatial_compression_ratio,
293
+ )
294
+
295
+ if latents is None:
296
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
297
+ else:
298
+ latents = latents.to(device)
299
+
300
+ # scale the initial noise by the standard deviation required by the scheduler
301
+ if hasattr(self.scheduler, "init_noise_sigma"):
302
+ latents = latents * self.scheduler.init_noise_sigma
303
+ return latents
304
+
305
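+ # Letterbox resize: scale the image to fit (height, width) while keeping its aspect ratio and pad the borders with `padding_color`.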
+ def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
306
+ ori_height = img_ori.shape[0]
307
+ ori_width = img_ori.shape[1]
308
+ channel = img_ori.shape[2]
309
+
310
+ img_pad = np.zeros((height, width, channel))
311
+ if channel == 1:
312
+ img_pad[:, :, 0] = padding_color[0]
313
+ else:
314
+ img_pad[:, :, 0] = padding_color[0]
315
+ img_pad[:, :, 1] = padding_color[1]
316
+ img_pad[:, :, 2] = padding_color[2]
317
+
318
+ if (ori_height / ori_width) > (height / width):
319
+ new_width = int(height / ori_height * ori_width)
320
+ img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
321
+ padding = int((width - new_width) / 2)
322
+ if len(img.shape) == 2:
323
+ img = img[:, :, np.newaxis]
324
+ img_pad[:, padding: padding + new_width, :] = img
325
+ else:
326
+ new_height = int(width / ori_width * ori_height)
327
+ img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
328
+ padding = int((height - new_height) / 2)
329
+ if len(img.shape) == 2:
330
+ img = img[:, :, np.newaxis]
331
+ img_pad[padding: padding + new_height, :, :] = img
332
+
333
+ img_pad = np.uint8(img_pad)
334
+
335
+ return img_pad
336
+
337
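+ # Extend a frame sequence to `target_len` by ping-pong looping (forward, then backward) over the source frames.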
+ def inputs_padding(self, x, target_len):
338
+ ndim = x.ndim
339
+
340
+ if ndim == 4:
341
+ f = x.shape[0]
342
+ if target_len <= f:
343
+ return [deepcopy(x[i]) for i in range(target_len)]
344
+
345
+ idx = 0
346
+ flip = False
347
+ target_array = []
348
+ while len(target_array) < target_len:
349
+ target_array.append(deepcopy(x[idx]))
350
+ if flip:
351
+ idx -= 1
352
+ else:
353
+ idx += 1
354
+ if idx == 0 or idx == f - 1:
355
+ flip = not flip
356
+ return target_array[:target_len]
357
+
358
+ elif ndim == 5:
359
+ b, c, f, h, w = x.shape
360
+
361
+ if target_len <= f:
362
+ return x[:, :, :target_len, :, :]
363
+
364
+ indices = []
365
+ idx = 0
366
+ flip = False
367
+ while len(indices) < target_len:
368
+ indices.append(idx)
369
+ if flip:
370
+ idx -= 1
371
+ else:
372
+ idx += 1
373
+ if idx == 0 or idx == f - 1:
374
+ flip = not flip
375
+ indices = indices[:target_len]
376
+
377
+ if isinstance(x, torch.Tensor):
378
+ indices_tensor = torch.tensor(indices, device=x.device, dtype=torch.long)
379
+ return x[:, :, indices_tensor, :, :]
380
+ else:
381
+ indices_array = np.array(indices)
382
+ return x[:, :, indices_array, :, :]
383
+
384
+ else:
385
+ raise ValueError(f"Unsupported input dimension: {ndim}. Expected 4D or 5D.")
386
+
387
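+ # Round `real_len` up so the padded length splits exactly into sliding windows of `clip_len` frames overlapping by `overlap`.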
+ def get_valid_len(self, real_len, clip_len=81, overlap=1):
388
+ real_clip_len = clip_len - overlap
389
+ last_clip_num = (real_len - overlap) % real_clip_len
390
+ if last_clip_num == 0:
391
+ extra = 0
392
+ else:
393
+ extra = real_clip_len - last_clip_num
394
+ target_len = real_len + extra
395
+ return target_len
396
+
397
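+ # Load the driving pose and face videos with decord and letterbox-resize the reference image to the pose video resolution.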
+ def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
398
+ pose_video_reader = VideoReader(src_pose_path)
399
+ pose_len = len(pose_video_reader)
400
+ pose_idxs = list(range(pose_len))
401
+ pose_video = pose_video_reader.get_batch(pose_idxs).asnumpy()
402
+
403
+ face_video_reader = VideoReader(src_face_path)
404
+ face_len = len(face_video_reader)
405
+ face_idxs = list(range(face_len))
406
+ face_video = face_video_reader.get_batch(face_idxs).asnumpy()
407
+ height, width = pose_video[0].shape[:2]
408
+
409
+ ref_image = cv2.imread(src_ref_path)[..., ::-1]
410
+ ref_image = self.padding_resize(ref_image, height=height, width=width)
411
+ return pose_video, face_video, ref_image
412
+
413
+ def prepare_source_for_replace(self, src_bg_path, src_mask_path):
414
+ bg_video_reader = VideoReader(src_bg_path)
415
+ bg_len = len(bg_video_reader)
416
+ bg_idxs = list(range(bg_len))
417
+ bg_video = bg_video_reader.get_batch(bg_idxs).asnumpy()
418
+
419
+ mask_video_reader = VideoReader(src_mask_path)
420
+ mask_len = len(mask_video_reader)
421
+ mask_idxs = list(range(mask_len))
422
+ mask_video = mask_video_reader.get_batch(mask_idxs).asnumpy()
423
+ mask_video = mask_video[:, :, :, 0] / 255
424
+ return bg_video, mask_video
425
+
426
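+ # Build the I2V conditioning mask in latent space: the first `mask_len` frames are marked as known (1),
+ # the leading frame is repeated 4x to match the VAE temporal compression, and the result is reshaped to (1, 4, lat_t, lat_h, lat_w).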
+ def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
427
+ if mask_pixel_values is None:
428
+ msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
429
+ else:
430
+ msk = mask_pixel_values.clone()
431
+ msk[:, :mask_len] = 1
432
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
433
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
434
+ msk = msk.transpose(1, 2)
435
+ return msk
436
+
437
+ def prepare_control_latents(
438
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
439
+ ):
440
+ # resize the control to latents shape as we concatenate the control to the latents
441
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
442
+ # and half precision
443
+
444
+ if control is not None:
445
+ control = control.to(device=device, dtype=dtype)
446
+ bs = 1
447
+ new_control = []
448
+ for i in range(0, control.shape[0], bs):
449
+ control_bs = control[i : i + bs]
450
+ control_bs = self.vae.encode(control_bs)[0]
451
+ control_bs = control_bs.mode()
452
+ new_control.append(control_bs)
453
+ control = torch.cat(new_control, dim = 0)
454
+
455
+ if control_image is not None:
456
+ control_image = control_image.to(device=device, dtype=dtype)
457
+ bs = 1
458
+ new_control_pixel_values = []
459
+ for i in range(0, control_image.shape[0], bs):
460
+ control_pixel_values_bs = control_image[i : i + bs]
461
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
462
+ control_pixel_values_bs = control_pixel_values_bs.mode()
463
+ new_control_pixel_values.append(control_pixel_values_bs)
464
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
465
+ else:
466
+ control_image_latents = None
467
+
468
+ return control, control_image_latents
469
+
470
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
471
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
472
+ frames = (frames / 2 + 0.5).clamp(0, 1)
473
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
474
+ # frames = frames.cpu().float().numpy()
475
+ return frames
476
+
477
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
478
+ def prepare_extra_step_kwargs(self, generator, eta):
479
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
480
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
481
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
482
+ # and should be between [0, 1]
483
+
484
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
485
+ extra_step_kwargs = {}
486
+ if accepts_eta:
487
+ extra_step_kwargs["eta"] = eta
488
+
489
+ # check if the scheduler accepts generator
490
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
491
+ if accepts_generator:
492
+ extra_step_kwargs["generator"] = generator
493
+ return extra_step_kwargs
494
+
495
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
496
+ def check_inputs(
497
+ self,
498
+ prompt,
499
+ height,
500
+ width,
501
+ negative_prompt,
502
+ callback_on_step_end_tensor_inputs,
503
+ prompt_embeds=None,
504
+ negative_prompt_embeds=None,
505
+ ):
506
+ if height % 8 != 0 or width % 8 != 0:
507
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
508
+
509
+ if callback_on_step_end_tensor_inputs is not None and not all(
510
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
511
+ ):
512
+ raise ValueError(
513
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
514
+ )
515
+ if prompt is not None and prompt_embeds is not None:
516
+ raise ValueError(
517
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
518
+ " only forward one of the two."
519
+ )
520
+ elif prompt is None and prompt_embeds is None:
521
+ raise ValueError(
522
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
523
+ )
524
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
525
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
526
+
527
+ if prompt is not None and negative_prompt_embeds is not None:
528
+ raise ValueError(
529
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
530
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
531
+ )
532
+
533
+ if negative_prompt is not None and negative_prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
536
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
537
+ )
538
+
539
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
540
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
541
+ raise ValueError(
542
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
543
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
544
+ f" {negative_prompt_embeds.shape}."
545
+ )
546
+
547
+ @property
548
+ def guidance_scale(self):
549
+ return self._guidance_scale
550
+
551
+ @property
552
+ def num_timesteps(self):
553
+ return self._num_timesteps
554
+
555
+ @property
556
+ def attention_kwargs(self):
557
+ return self._attention_kwargs
558
+
559
+ @property
560
+ def interrupt(self):
561
+ return self._interrupt
562
+
563
+ @torch.no_grad()
564
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
565
+ def __call__(
566
+ self,
567
+ prompt: Optional[Union[str, List[str]]] = None,
568
+ negative_prompt: Optional[Union[str, List[str]]] = None,
569
+ height: int = 480,
570
+ width: int = 720,
571
+ clip_len=77,
572
+ num_frames: int = 49,
573
+ num_inference_steps: int = 50,
574
+ pose_video = None,
575
+ face_video = None,
576
+ ref_image = None,
577
+ bg_video = None,
578
+ mask_video = None,
579
+ replace_flag = True,
580
+ timesteps: Optional[List[int]] = None,
581
+ guidance_scale: float = 6,
582
+ num_videos_per_prompt: int = 1,
583
+ eta: float = 0.0,
584
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
585
+ latents: Optional[torch.FloatTensor] = None,
586
+ prompt_embeds: Optional[torch.FloatTensor] = None,
587
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
588
+ output_type: str = "numpy",
589
+ return_dict: bool = False,
590
+ callback_on_step_end: Optional[
591
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
592
+ ] = None,
593
+ attention_kwargs: Optional[Dict[str, Any]] = None,
594
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
595
+ max_sequence_length: int = 512,
596
+ boundary: float = 0.875,
597
+ comfyui_progressbar: bool = False,
598
+ shift: int = 5,
599
+ refert_num = 1,
600
+ ) -> Union[WanPipelineOutput, Tuple]:
601
+ """
602
+ Function invoked when calling the pipeline for generation.
603
+ Args:
604
+
605
+ Examples:
606
+
607
+ Returns:
608
+
609
+ """
610
+
611
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
612
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
613
+ num_videos_per_prompt = 1
614
+
615
+ # 1. Check inputs. Raise error if not correct
616
+ self.check_inputs(
617
+ prompt,
618
+ height,
619
+ width,
620
+ negative_prompt,
621
+ callback_on_step_end_tensor_inputs,
622
+ prompt_embeds,
623
+ negative_prompt_embeds,
624
+ )
625
+ self._guidance_scale = guidance_scale
626
+ self._attention_kwargs = attention_kwargs
627
+ self._interrupt = False
628
+
629
+ # 2. Default call parameters
630
+ if prompt is not None and isinstance(prompt, str):
631
+ batch_size = 1
632
+ elif prompt is not None and isinstance(prompt, list):
633
+ batch_size = len(prompt)
634
+ else:
635
+ batch_size = prompt_embeds.shape[0]
636
+
637
+ device = self._execution_device
638
+ weight_dtype = self.text_encoder.dtype
639
+
640
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
641
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
642
+ # corresponds to doing no classifier free guidance.
643
+ do_classifier_free_guidance = guidance_scale > 1.0
644
+
645
+ # 3. Encode input prompt
646
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
647
+ prompt,
648
+ negative_prompt,
649
+ do_classifier_free_guidance,
650
+ num_videos_per_prompt=num_videos_per_prompt,
651
+ prompt_embeds=prompt_embeds,
652
+ negative_prompt_embeds=negative_prompt_embeds,
653
+ max_sequence_length=max_sequence_length,
654
+ device=device,
655
+ )
656
+ if do_classifier_free_guidance:
657
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
658
+ else:
659
+ in_prompt_embeds = prompt_embeds
660
+
661
+ if comfyui_progressbar:
662
+ from comfy.utils import ProgressBar
663
+ pbar = ProgressBar(num_inference_steps + 1)
664
+
665
+ # 4. Prepare latents
666
+ if pose_video is not None:
667
+ video_length = pose_video.shape[2]
668
+ pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width)
669
+ pose_video = pose_video.to(dtype=torch.float32)
670
+ pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length)
671
+ else:
672
+ pose_video = None
673
+
674
+ if face_video is not None:
675
+ video_length = face_video.shape[2]
676
+ face_video = self.image_processor.preprocess(rearrange(face_video, "b c f h w -> (b f) c h w"))
677
+ face_video = face_video.to(dtype=torch.float32)
678
+ face_video = rearrange(face_video, "(b f) c h w -> b c f h w", f=video_length)
679
+ else:
680
+ face_video = None
681
+
682
+ real_frame_len = pose_video.size()[2]
683
+ target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
684
+ print('real frames: {} target frames: {}'.format(real_frame_len, target_len))
685
+ pose_video = self.inputs_padding(pose_video, target_len).to(device, weight_dtype)
686
+ face_video = self.inputs_padding(face_video, target_len).to(device, weight_dtype)
687
+ ref_image = self.padding_resize(np.array(ref_image), height=height, width=width)
688
+ ref_image = torch.tensor(ref_image / 127.5 - 1).unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0).to(device, weight_dtype)
689
+
690
+ if replace_flag:
691
+ if bg_video is not None:
692
+ video_length = bg_video.shape[2]
693
+ bg_video = self.image_processor.preprocess(rearrange(bg_video, "b c f h w -> (b f) c h w"), height=height, width=width)
694
+ bg_video = bg_video.to(dtype=torch.float32)
695
+ bg_video = rearrange(bg_video, "(b f) c h w -> b c f h w", f=video_length)
696
+ else:
697
+ bg_video = None
698
+ bg_video = self.inputs_padding(bg_video, target_len).to(device, weight_dtype)
699
+ mask_video = self.inputs_padding(mask_video, target_len).to(device, weight_dtype)
700
+
701
+ if comfyui_progressbar:
702
+ pbar.update(1)
703
+
704
+ # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
705
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
706
+
707
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
708
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
709
+
710
+ # 6. Denoising loop
711
+ start = 0
712
+ end = clip_len
713
+ all_out_frames = []
714
+ copy_timesteps = copy.deepcopy(timesteps)
715
+ copy_latents = copy.deepcopy(latents)
716
+ bs = pose_video.size()[0]
717
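+ # Sliding-window generation: denoise `clip_len`-frame windows one at a time, carrying the last
+ # `refert_num` generated frames over as temporal reference for the next window.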
+ while True:
718
+ if start + refert_num >= pose_video.size()[2]:
719
+ break
720
+
721
+ # Prepare timesteps
722
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
723
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1)
724
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
725
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
726
+ timesteps = self.scheduler.timesteps
727
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
728
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
729
+ timesteps, _ = retrieve_timesteps(
730
+ self.scheduler,
731
+ device=device,
732
+ sigmas=sampling_sigmas)
733
+ else:
734
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps)
735
+ self._num_timesteps = len(timesteps)
736
+
737
+ latent_channels = self.transformer.config.in_channels
738
+ latents = self.prepare_latents(
739
+ batch_size * num_videos_per_prompt,
740
+ latent_channels,
741
+ num_frames,
742
+ height,
743
+ width,
744
+ weight_dtype,
745
+ device,
746
+ generator,
747
+ copy_latents,
748
+ )
749
+
750
+ if start == 0:
751
+ mask_reft_len = 0
752
+ else:
753
+ mask_reft_len = refert_num
754
+
755
+ conditioning_pixel_values = pose_video[:, :, start:end]
756
+ face_pixel_values = face_video[:, :, start:end]
757
+ ref_pixel_values = ref_image.clone().detach()
758
+ if start > 0:
759
+ refer_t_pixel_values = out_frames[:, :, -refert_num:].clone().detach()
760
+ refer_t_pixel_values = (refer_t_pixel_values - 0.5) / 0.5
761
+ else:
762
+ refer_t_pixel_values = torch.zeros(bs, 3, refert_num, height, width)
763
+ refer_t_pixel_values = refer_t_pixel_values.to(device=device, dtype=weight_dtype)
764
+
765
+ pose_latents, ref_latents = self.prepare_control_latents(
766
+ conditioning_pixel_values,
767
+ ref_pixel_values,
768
+ batch_size,
769
+ height,
770
+ width,
771
+ weight_dtype,
772
+ device,
773
+ generator,
774
+ do_classifier_free_guidance
775
+ )
776
+
777
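+ # Conditioning `y`: mask+latents of the reference image (y_ref) concatenated along the temporal axis
+ # with mask+latents of the current window (y_reft), which is built below.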
+ mask_ref = self.get_i2v_mask(1, target_shape[-1], target_shape[-2], 1, device=device)
778
+ y_ref = torch.concat([mask_ref, ref_latents], dim=1).to(device=device, dtype=weight_dtype)
779
+ if mask_reft_len > 0:
780
+ if replace_flag:
781
+ # Image.fromarray(np.array((refer_t_pixel_values[0, :, 0].permute(1,2,0) * 0.5 + 0.5).float().cpu().numpy() *255, np.uint8)).save("1.jpg")
782
+ bg_pixel_values = bg_video[:, :, start:end]
783
+ y_reft = self.vae.encode(
784
+ torch.concat(
785
+ [
786
+ refer_t_pixel_values[:, :, :mask_reft_len],
787
+ bg_pixel_values[:, :, mask_reft_len:]
788
+ ], dim=2
789
+ ).to(device=device, dtype=weight_dtype)
790
+ )[0].mode()
791
+
792
+ mask_pixel_values = 1 - mask_video[:, :, start:end]
793
+ mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
794
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest')
795
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0]
796
+ msk_reft = self.get_i2v_mask(
797
+ int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device
798
+ )
799
+ else:
800
+ refer_t_pixel_values = rearrange(refer_t_pixel_values[:, :, :mask_reft_len], "b c t h w -> (b t) c h w")
801
+ refer_t_pixel_values = F.interpolate(refer_t_pixel_values, size=(height, width), mode="bicubic")
802
+ refer_t_pixel_values = rearrange(refer_t_pixel_values, "(b t) c h w -> b c t h w", b = bs)
803
+
804
+ y_reft = self.vae.encode(
805
+ torch.concat(
806
+ [
807
+ refer_t_pixel_values,
808
+ torch.zeros(bs, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype),
809
+ ], dim=2,
810
+ ).to(device=device, dtype=weight_dtype)
811
+ )[0].mode()
812
+ msk_reft = self.get_i2v_mask(
813
+ int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device
814
+ )
815
+ else:
816
+ if replace_flag:
817
+ bg_pixel_values = bg_video[:, :, start:end]
818
+ y_reft = self.vae.encode(
819
+ bg_pixel_values.to(device=device, dtype=weight_dtype)
820
+ )[0].mode()
821
+
822
+ mask_pixel_values = 1 - mask_video[:, :, start:end]
823
+ mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
824
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest')
825
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0]
826
+ msk_reft = self.get_i2v_mask(
827
+ int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device
828
+ )
829
+ else:
830
+ y_reft = self.vae.encode(
831
+ torch.zeros(1, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype)
832
+ )[0].mode()
833
+ msk_reft = self.get_i2v_mask(
834
+ int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device
835
+ )
836
+
837
+ y_reft = torch.concat([msk_reft, y_reft], dim=1).to(device=device, dtype=weight_dtype)
838
+ y = torch.concat([y_ref, y_reft], dim=2)
839
+
840
+ clip_context = self.clip_image_encoder([ref_pixel_values[0, :, :, :]]).to(device=device, dtype=weight_dtype)
841
+
842
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
843
+ self.transformer.num_inference_steps = num_inference_steps
844
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
845
+ for i, t in enumerate(timesteps):
846
+ self.transformer.current_steps = i
847
+
848
+ if self.interrupt:
849
+ continue
850
+
851
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
852
+ if hasattr(self.scheduler, "scale_model_input"):
853
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
854
+
855
+ y_in = torch.cat([y] * 2) if do_classifier_free_guidance else y
856
+ clip_context_input = (
857
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
858
+ )
859
+ pose_latents_input = (
860
+ torch.cat([pose_latents] * 2) if do_classifier_free_guidance else pose_latents
861
+ )
862
+ face_pixel_values_input = (
863
+ torch.cat([torch.ones_like(face_pixel_values) * -1] + [face_pixel_values]) if do_classifier_free_guidance else face_pixel_values
864
+ )
865
+
866
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
867
+ timestep = t.expand(latent_model_input.shape[0])
868
+
869
+ if self.transformer_2 is not None:
870
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
871
+ local_transformer = self.transformer_2
872
+ else:
873
+ local_transformer = self.transformer
874
+ else:
875
+ local_transformer = self.transformer
876
+
877
+ # predict noise model_output
878
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
879
+ noise_pred = local_transformer(
880
+ x=latent_model_input,
881
+ context=in_prompt_embeds,
882
+ t=timestep,
883
+ seq_len=seq_len,
884
+ y=y_in,
885
+ clip_fea=clip_context_input,
886
+ pose_latents=pose_latents_input,
887
+ face_pixel_values=face_pixel_values_input,
888
+ )
889
+
890
+ # Perform guidance
891
+ if do_classifier_free_guidance:
892
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
893
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
894
+ else:
895
+ sample_guide_scale = self.guidance_scale
896
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
897
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
898
+
899
+ # Compute the previous noisy sample x_t -> x_t-1
900
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
901
+
902
+ if callback_on_step_end is not None:
903
+ callback_kwargs = {}
904
+ for k in callback_on_step_end_tensor_inputs:
905
+ callback_kwargs[k] = locals()[k]
906
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
907
+
908
+ latents = callback_outputs.pop("latents", latents)
909
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
910
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
911
+
912
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
913
+ progress_bar.update()
914
+ if comfyui_progressbar:
915
+ pbar.update(1)
916
+
917
+ out_frames = self.decode_latents(latents[:, :, 1:])
918
+ if start != 0:
919
+ out_frames = out_frames[:, :, refert_num:]
920
+ all_out_frames.append(out_frames.cpu())
921
+ start += clip_len - refert_num
922
+ end += clip_len - refert_num
923
+
924
+ videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
925
+
926
+ # Offload all models
927
+ self.maybe_free_model_hooks()
928
+
929
+ return WanPipelineOutput(videos=videos.float().cpu())
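The denoising loop above generates long videos in overlapping temporal windows: each pass denoises `clip_len` frames, re-uses the last `refert_num` decoded frames of the previous window as reference, keeps only the non-overlapping tail of each decoded clip, and finally trims the concatenation back to `real_frame_len`. A minimal sketch of that window arithmetic, with `window_schedule` as a hypothetical helper that is not part of the pipeline:

```python
# Minimal sketch of the overlapping-window schedule used by the loop above.
# `total_frames`, `clip_len`, and `refert_num` are assumed inputs; the real
# pipeline additionally pads the conditioning videos so the last window is full-length.
def window_schedule(total_frames: int, clip_len: int, refert_num: int):
    windows = []
    start, end = 0, clip_len
    while start + refert_num < total_frames:   # same stop condition as the loop above
        windows.append((start, end))
        start += clip_len - refert_num         # advance by window length minus overlap
        end += clip_len - refert_num
    return windows

# Example: 149 padded frames, 77-frame windows, 5 overlapping reference frames
# -> [(0, 77), (72, 149)]; after the first window only frames [refert_num:] of
# each decoded clip are appended before trimming to real_frame_len.
print(window_schedule(149, 77, 5))
```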
videox_fun/pipeline/pipeline_wan2_2_fun_control.py ADDED
@@ -0,0 +1,903 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
24
+ Wan2_2Transformer3DModel, WanT5EncoderModel)
25
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
26
+ get_sampling_sigmas)
27
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```python
35
+ pass
36
+ ```
37
+ """
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
74
+ if timesteps is not None:
75
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
76
+ if not accepts_timesteps:
77
+ raise ValueError(
78
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
79
+ f" timestep schedules. Please check whether you are using the correct scheduler."
80
+ )
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
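A hedged usage sketch of `retrieve_timesteps`: the three entry points (`num_inference_steps`, `timesteps`, `sigmas`) are mutually exclusive, and the pipelines below pick one per scheduler type (for example `mu=1` for flow-matching Euler, custom `sigmas` for the DPM-Solver variant). The scheduler construction below is illustrative only:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Illustrative standard call: let the scheduler build its own timestep spacing.
scheduler = FlowMatchEulerDiscreteScheduler(shift=5.0)
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
print(num_steps, timesteps[:3])   # 50 steps, highest noise levels first
```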
98
+
99
+
100
+ def resize_mask(mask, latent, process_first_frame_only=True):
101
+ latent_size = latent.size()
102
+ batch_size, channels, num_frames, height, width = mask.shape
103
+
104
+ if process_first_frame_only:
105
+ target_size = list(latent_size[2:])
106
+ target_size[0] = 1
107
+ first_frame_resized = F.interpolate(
108
+ mask[:, :, 0:1, :, :],
109
+ size=target_size,
110
+ mode='trilinear',
111
+ align_corners=False
112
+ )
113
+
114
+ target_size = list(latent_size[2:])
115
+ target_size[0] = target_size[0] - 1
116
+ if target_size[0] != 0:
117
+ remaining_frames_resized = F.interpolate(
118
+ mask[:, :, 1:, :, :],
119
+ size=target_size,
120
+ mode='trilinear',
121
+ align_corners=False
122
+ )
123
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
124
+ else:
125
+ resized_mask = first_frame_resized
126
+ else:
127
+ target_size = list(latent_size[2:])
128
+ resized_mask = F.interpolate(
129
+ mask,
130
+ size=target_size,
131
+ mode='trilinear',
132
+ align_corners=False
133
+ )
134
+ return resized_mask
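`resize_mask` downsamples a pixel-space mask onto the latent grid; the first frame is interpolated separately because the Wan VAE maps the first pixel frame to its own latent frame and temporally compresses the remaining frames. A shape-check sketch, where the concrete sizes and compression ratios are illustrative assumptions:

```python
import torch

# 49 input frames with temporal compression 4 and spatial compression 8
# give a latent grid of (49 - 1) // 4 + 1 = 13 frames at 60 x 104.
mask = torch.ones(1, 1, 49, 480, 832)      # (b, c, t, h, w) pixel-space mask
latent = torch.zeros(1, 16, 13, 60, 104)   # (b, c, t', h', w') latents
resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)                        # torch.Size([1, 1, 13, 60, 104])
```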
135
+
136
+
137
+ @dataclass
138
+ class WanPipelineOutput(BaseOutput):
139
+ r"""
140
+ Output class for Wan pipelines.
141
+
142
+ Args:
143
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
144
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
145
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
146
+ `(batch_size, num_frames, channels, height, width)`.
147
+ """
148
+
149
+ videos: torch.Tensor
150
+
151
+
152
+ class Wan2_2FunControlPipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline for control-conditioned video generation using Wan2.2-Fun.
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
157
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
158
+ """
159
+
160
+ _optional_components = ["transformer_2"]
161
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
162
+
163
+ _callback_tensor_inputs = [
164
+ "latents",
165
+ "prompt_embeds",
166
+ "negative_prompt_embeds",
167
+ ]
168
+
169
+ def __init__(
170
+ self,
171
+ tokenizer: AutoTokenizer,
172
+ text_encoder: WanT5EncoderModel,
173
+ vae: AutoencoderKLWan,
174
+ transformer: Wan2_2Transformer3DModel,
175
+ transformer_2: Wan2_2Transformer3DModel = None,
176
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
177
+ ):
178
+ super().__init__()
179
+
180
+ self.register_modules(
181
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
182
+ transformer_2=transformer_2, scheduler=scheduler
183
+ )
184
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
186
+ self.mask_processor = VaeImageProcessor(
187
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
188
+ )
189
+
190
+ def _get_t5_prompt_embeds(
191
+ self,
192
+ prompt: Union[str, List[str]] = None,
193
+ num_videos_per_prompt: int = 1,
194
+ max_sequence_length: int = 512,
195
+ device: Optional[torch.device] = None,
196
+ dtype: Optional[torch.dtype] = None,
197
+ ):
198
+ device = device or self._execution_device
199
+ dtype = dtype or self.text_encoder.dtype
200
+
201
+ prompt = [prompt] if isinstance(prompt, str) else prompt
202
+ batch_size = len(prompt)
203
+
204
+ text_inputs = self.tokenizer(
205
+ prompt,
206
+ padding="max_length",
207
+ max_length=max_sequence_length,
208
+ truncation=True,
209
+ add_special_tokens=True,
210
+ return_tensors="pt",
211
+ )
212
+ text_input_ids = text_inputs.input_ids
213
+ prompt_attention_mask = text_inputs.attention_mask
214
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
215
+
216
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
217
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
218
+ logger.warning(
219
+ "The following part of your input was truncated because `max_sequence_length` is set to "
220
+ f" {max_sequence_length} tokens: {removed_text}"
221
+ )
222
+
223
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
224
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
225
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ _, seq_len, _ = prompt_embeds.shape
229
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
230
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
231
+
232
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
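Note that `_get_t5_prompt_embeds` returns a Python list of per-prompt tensors trimmed to their true token counts, which is why `negative_prompt_embeds + prompt_embeds` later in `__call__` concatenates two lists rather than adding tensors. A small sketch of that trimming, where the hidden size and lengths are illustrative assumptions:

```python
import torch

padded = torch.randn(2, 512, 4096)        # padded batch from the text encoder
seq_lens = torch.tensor([17, 43])         # true (non-padded) token counts per prompt
trimmed = [u[:v] for u, v in zip(padded, seq_lens)]
print([tuple(t.shape) for t in trimmed])  # [(17, 4096), (43, 4096)]
```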
233
+
234
+ def encode_prompt(
235
+ self,
236
+ prompt: Union[str, List[str]],
237
+ negative_prompt: Optional[Union[str, List[str]]] = None,
238
+ do_classifier_free_guidance: bool = True,
239
+ num_videos_per_prompt: int = 1,
240
+ prompt_embeds: Optional[torch.Tensor] = None,
241
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
242
+ max_sequence_length: int = 512,
243
+ device: Optional[torch.device] = None,
244
+ dtype: Optional[torch.dtype] = None,
245
+ ):
246
+ r"""
247
+ Encodes the prompt into text encoder hidden states.
248
+
249
+ Args:
250
+ prompt (`str` or `List[str]`, *optional*):
251
+ prompt to be encoded
252
+ negative_prompt (`str` or `List[str]`, *optional*):
253
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
254
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
255
+ less than `1`).
256
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
257
+ Whether to use classifier free guidance or not.
258
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
259
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
260
+ prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
262
+ provided, text embeddings will be generated from `prompt` input argument.
263
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
264
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
265
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
266
+ argument.
267
+ device: (`torch.device`, *optional*):
268
+ torch device
269
+ dtype: (`torch.dtype`, *optional*):
270
+ torch dtype
271
+ """
272
+ device = device or self._execution_device
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ if prompt is not None:
276
+ batch_size = len(prompt)
277
+ else:
278
+ batch_size = prompt_embeds.shape[0]
279
+
280
+ if prompt_embeds is None:
281
+ prompt_embeds = self._get_t5_prompt_embeds(
282
+ prompt=prompt,
283
+ num_videos_per_prompt=num_videos_per_prompt,
284
+ max_sequence_length=max_sequence_length,
285
+ device=device,
286
+ dtype=dtype,
287
+ )
288
+
289
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
290
+ negative_prompt = negative_prompt or ""
291
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
292
+
293
+ if prompt is not None and type(prompt) is not type(negative_prompt):
294
+ raise TypeError(
295
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
296
+ f" {type(prompt)}."
297
+ )
298
+ elif batch_size != len(negative_prompt):
299
+ raise ValueError(
300
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
301
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
302
+ " the batch size of `prompt`."
303
+ )
304
+
305
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
306
+ prompt=negative_prompt,
307
+ num_videos_per_prompt=num_videos_per_prompt,
308
+ max_sequence_length=max_sequence_length,
309
+ device=device,
310
+ dtype=dtype,
311
+ )
312
+
313
+ return prompt_embeds, negative_prompt_embeds
314
+
315
+ def prepare_latents(
316
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
317
+ ):
318
+ if isinstance(generator, list) and len(generator) != batch_size:
319
+ raise ValueError(
320
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
321
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
322
+ )
323
+
324
+ shape = (
325
+ batch_size,
326
+ num_channels_latents,
327
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
328
+ height // self.vae.spatial_compression_ratio,
329
+ width // self.vae.spatial_compression_ratio,
330
+ )
331
+
332
+ if latents is None:
333
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
334
+ else:
335
+ latents = latents.to(device)
336
+
337
+ # scale the initial noise by the standard deviation required by the scheduler
338
+ if hasattr(self.scheduler, "init_noise_sigma"):
339
+ latents = latents * self.scheduler.init_noise_sigma
340
+ return latents
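The latent shape prepared above follows directly from the VAE compression ratios; a worked example with ratios typical of Wan-style VAEs (temporal 4 and spatial 8 are assumptions here, read from `self.vae` at runtime):

```python
num_frames, height, width = 81, 480, 832
latent_frames = (num_frames - 1) // 4 + 1              # 21 latent frames
latent_height, latent_width = height // 8, width // 8  # 60, 104
print(latent_frames, latent_height, latent_width)      # 21 60 104
```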
341
+
342
+ def prepare_mask_latents(
343
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
344
+ ):
345
+ # resize the mask to latents shape as we concatenate the mask to the latents
346
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
347
+ # and half precision
348
+
349
+ if mask is not None:
350
+ mask = mask.to(device=device, dtype=self.vae.dtype)
351
+ bs = 1
352
+ new_mask = []
353
+ for i in range(0, mask.shape[0], bs):
354
+ mask_bs = mask[i : i + bs]
355
+ mask_bs = self.vae.encode(mask_bs)[0]
356
+ mask_bs = mask_bs.mode()
357
+ new_mask.append(mask_bs)
358
+ mask = torch.cat(new_mask, dim = 0)
359
+ # mask = mask * self.vae.config.scaling_factor
360
+
361
+ if masked_image is not None:
362
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
363
+ bs = 1
364
+ new_mask_pixel_values = []
365
+ for i in range(0, masked_image.shape[0], bs):
366
+ mask_pixel_values_bs = masked_image[i : i + bs]
367
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
368
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
369
+ new_mask_pixel_values.append(mask_pixel_values_bs)
370
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
371
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
372
+ else:
373
+ masked_image_latents = None
374
+
375
+ return mask, masked_image_latents
376
+
377
+ def prepare_control_latents(
378
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
379
+ ):
380
+ # resize the control to latents shape as we concatenate the control to the latents
381
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
382
+ # and half precision
383
+
384
+ if control is not None:
385
+ control = control.to(device=device, dtype=dtype)
386
+ bs = 1
387
+ new_control = []
388
+ for i in range(0, control.shape[0], bs):
389
+ control_bs = control[i : i + bs]
390
+ control_bs = self.vae.encode(control_bs)[0]
391
+ control_bs = control_bs.mode()
392
+ new_control.append(control_bs)
393
+ control = torch.cat(new_control, dim = 0)
394
+
395
+ if control_image is not None:
396
+ control_image = control_image.to(device=device, dtype=dtype)
397
+ bs = 1
398
+ new_control_pixel_values = []
399
+ for i in range(0, control_image.shape[0], bs):
400
+ control_pixel_values_bs = control_image[i : i + bs]
401
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
402
+ control_pixel_values_bs = control_pixel_values_bs.mode()
403
+ new_control_pixel_values.append(control_pixel_values_bs)
404
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
405
+ else:
406
+ control_image_latents = None
407
+
408
+ return control, control_image_latents
409
+
410
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
411
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
412
+ frames = (frames / 2 + 0.5).clamp(0, 1)
413
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
414
+ frames = frames.cpu().float().numpy()
415
+ return frames
416
+
417
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
418
+ def prepare_extra_step_kwargs(self, generator, eta):
419
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
420
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
421
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
422
+ # and should be between [0, 1]
423
+
424
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
425
+ extra_step_kwargs = {}
426
+ if accepts_eta:
427
+ extra_step_kwargs["eta"] = eta
428
+
429
+ # check if the scheduler accepts generator
430
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
431
+ if accepts_generator:
432
+ extra_step_kwargs["generator"] = generator
433
+ return extra_step_kwargs
434
+
435
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
436
+ def check_inputs(
437
+ self,
438
+ prompt,
439
+ height,
440
+ width,
441
+ negative_prompt,
442
+ callback_on_step_end_tensor_inputs,
443
+ prompt_embeds=None,
444
+ negative_prompt_embeds=None,
445
+ ):
446
+ if height % 8 != 0 or width % 8 != 0:
447
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
448
+
449
+ if callback_on_step_end_tensor_inputs is not None and not all(
450
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
451
+ ):
452
+ raise ValueError(
453
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
454
+ )
455
+ if prompt is not None and prompt_embeds is not None:
456
+ raise ValueError(
457
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
458
+ " only forward one of the two."
459
+ )
460
+ elif prompt is None and prompt_embeds is None:
461
+ raise ValueError(
462
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
463
+ )
464
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
465
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
466
+
467
+ if prompt is not None and negative_prompt_embeds is not None:
468
+ raise ValueError(
469
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
470
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
471
+ )
472
+
473
+ if negative_prompt is not None and negative_prompt_embeds is not None:
474
+ raise ValueError(
475
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
476
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
477
+ )
478
+
479
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
480
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
481
+ raise ValueError(
482
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
483
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
484
+ f" {negative_prompt_embeds.shape}."
485
+ )
486
+
487
+ @property
488
+ def guidance_scale(self):
489
+ return self._guidance_scale
490
+
491
+ @property
492
+ def num_timesteps(self):
493
+ return self._num_timesteps
494
+
495
+ @property
496
+ def attention_kwargs(self):
497
+ return self._attention_kwargs
498
+
499
+ @property
500
+ def interrupt(self):
501
+ return self._interrupt
502
+
503
+ @torch.no_grad()
504
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
505
+ def __call__(
506
+ self,
507
+ prompt: Optional[Union[str, List[str]]] = None,
508
+ negative_prompt: Optional[Union[str, List[str]]] = None,
509
+ height: int = 480,
510
+ width: int = 720,
511
+ video: Union[torch.FloatTensor] = None,
512
+ mask_video: Union[torch.FloatTensor] = None,
513
+ control_video: Union[torch.FloatTensor] = None,
514
+ control_camera_video: Union[torch.FloatTensor] = None,
515
+ start_image: Union[torch.FloatTensor] = None,
516
+ ref_image: Union[torch.FloatTensor] = None,
517
+ num_frames: int = 49,
518
+ num_inference_steps: int = 50,
519
+ timesteps: Optional[List[int]] = None,
520
+ guidance_scale: float = 6,
521
+ num_videos_per_prompt: int = 1,
522
+ eta: float = 0.0,
523
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
524
+ latents: Optional[torch.FloatTensor] = None,
525
+ prompt_embeds: Optional[torch.FloatTensor] = None,
526
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
527
+ output_type: str = "numpy",
528
+ return_dict: bool = False,
529
+ callback_on_step_end: Optional[
530
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
531
+ ] = None,
532
+ attention_kwargs: Optional[Dict[str, Any]] = None,
533
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
534
+ max_sequence_length: int = 512,
535
+ boundary: float = 0.875,
536
+ comfyui_progressbar: bool = False,
537
+ shift: int = 5,
538
+ ) -> Union[WanPipelineOutput, Tuple]:
539
+ """
540
+ Function invoked when calling the pipeline for generation.
541
+ Args:
542
+
543
+ Examples:
544
+
545
+ Returns:
546
+
547
+ """
548
+
549
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
550
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
551
+ num_videos_per_prompt = 1
552
+
553
+ # 1. Check inputs. Raise error if not correct
554
+ self.check_inputs(
555
+ prompt,
556
+ height,
557
+ width,
558
+ negative_prompt,
559
+ callback_on_step_end_tensor_inputs,
560
+ prompt_embeds,
561
+ negative_prompt_embeds,
562
+ )
563
+ self._guidance_scale = guidance_scale
564
+ self._attention_kwargs = attention_kwargs
565
+ self._interrupt = False
566
+
567
+ # 2. Default call parameters
568
+ if prompt is not None and isinstance(prompt, str):
569
+ batch_size = 1
570
+ elif prompt is not None and isinstance(prompt, list):
571
+ batch_size = len(prompt)
572
+ else:
573
+ batch_size = prompt_embeds.shape[0]
574
+
575
+ device = self._execution_device
576
+ weight_dtype = self.text_encoder.dtype
577
+
578
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
579
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
580
+ # corresponds to doing no classifier free guidance.
581
+ do_classifier_free_guidance = guidance_scale > 1.0
582
+
583
+ # 3. Encode input prompt
584
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
585
+ prompt,
586
+ negative_prompt,
587
+ do_classifier_free_guidance,
588
+ num_videos_per_prompt=num_videos_per_prompt,
589
+ prompt_embeds=prompt_embeds,
590
+ negative_prompt_embeds=negative_prompt_embeds,
591
+ max_sequence_length=max_sequence_length,
592
+ device=device,
593
+ )
594
+ if do_classifier_free_guidance:
595
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
596
+ else:
597
+ in_prompt_embeds = prompt_embeds
598
+
599
+ # 4. Prepare timesteps
600
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
601
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
602
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
603
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
604
+ timesteps = self.scheduler.timesteps
605
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
606
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
607
+ timesteps, _ = retrieve_timesteps(
608
+ self.scheduler,
609
+ device=device,
610
+ sigmas=sampling_sigmas)
611
+ else:
612
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
613
+ self._num_timesteps = len(timesteps)
614
+ if comfyui_progressbar:
615
+ from comfy.utils import ProgressBar
616
+ pbar = ProgressBar(num_inference_steps + 2)
617
+
618
+ # 5. Prepare latents.
619
+ if video is not None:
620
+ video_length = video.shape[2]
621
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
622
+ init_video = init_video.to(dtype=torch.float32)
623
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
624
+ else:
625
+ init_video = None
626
+
627
+ latent_channels = self.vae.config.latent_channels
628
+ latents = self.prepare_latents(
629
+ batch_size * num_videos_per_prompt,
630
+ latent_channels,
631
+ num_frames,
632
+ height,
633
+ width,
634
+ weight_dtype,
635
+ device,
636
+ generator,
637
+ latents,
638
+ )
639
+ if comfyui_progressbar:
640
+ pbar.update(1)
641
+
642
+ # Prepare mask latent variables
643
+ if init_video is not None:
644
+ if (mask_video == 255).all():
645
+ mask_latents = torch.tile(
646
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
647
+ )
648
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
649
+ if self.vae.spatial_compression_ratio >= 16:
650
+ mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype)
651
+ else:
652
+ bs, _, video_length, height, width = video.size()
653
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
654
+ mask_condition = mask_condition.to(dtype=torch.float32)
655
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
656
+
657
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
658
+ _, masked_video_latents = self.prepare_mask_latents(
659
+ None,
660
+ masked_video,
661
+ batch_size,
662
+ height,
663
+ width,
664
+ weight_dtype,
665
+ device,
666
+ generator,
667
+ do_classifier_free_guidance,
668
+ noise_aug_strength=None,
669
+ )
670
+
671
+ mask_condition = torch.concat(
672
+ [
673
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
674
+ mask_condition[:, :, 1:]
675
+ ], dim=2
676
+ )
677
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
678
+ mask_condition = mask_condition.transpose(1, 2)
679
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
680
+
681
+ if self.vae.spatial_compression_ratio >= 16:
682
+ mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
683
+ if not mask[:, :, 0, :, :].any():
684
+ mask[:, :, 1:, :, :] = 1
685
+ latents = (1 - mask) * masked_video_latents + mask * latents
686
+
687
+ # Prepare mask latent variables
688
+ if control_camera_video is not None:
689
+ control_latents = None
690
+ # Rearrange dimensions
691
+ # Concatenate and transpose dimensions
692
+ control_camera_latents = torch.concat(
693
+ [
694
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
695
+ control_camera_video[:, :, 1:]
696
+ ], dim=2
697
+ ).transpose(1, 2)
698
+
699
+ # Reshape, transpose, and view into desired shape
700
+ b, f, c, h, w = control_camera_latents.shape
701
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
702
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
703
+ elif control_video is not None:
704
+ video_length = control_video.shape[2]
705
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
706
+ control_video = control_video.to(dtype=torch.float32)
707
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
708
+ control_video_latents = self.prepare_control_latents(
709
+ None,
710
+ control_video,
711
+ batch_size,
712
+ height,
713
+ width,
714
+ weight_dtype,
715
+ device,
716
+ generator,
717
+ do_classifier_free_guidance
718
+ )[1]
719
+ control_camera_latents = None
720
+ else:
721
+ control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
722
+ control_camera_latents = None
723
+
724
+ if start_image is not None:
725
+ video_length = start_image.shape[2]
726
+ start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width)
727
+ start_image = start_image.to(dtype=torch.float32)
728
+ start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length)
729
+
730
+ start_image_latentes = self.prepare_control_latents(
731
+ None,
732
+ start_image,
733
+ batch_size,
734
+ height,
735
+ width,
736
+ weight_dtype,
737
+ device,
738
+ generator,
739
+ do_classifier_free_guidance
740
+ )[1]
741
+
742
+ start_image_latentes_conv_in = torch.zeros_like(latents)
743
+ if latents.size()[2] != 1:
744
+ start_image_latentes_conv_in[:, :, :1] = start_image_latentes
745
+ else:
746
+ start_image_latentes_conv_in = torch.zeros_like(latents)
747
+
748
+ if self.transformer.config.get("add_ref_conv", False):
749
+ if ref_image is not None:
750
+ video_length = ref_image.shape[2]
751
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
752
+ ref_image = ref_image.to(dtype=torch.float32)
753
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
754
+
755
+ ref_image_latentes = self.prepare_control_latents(
756
+ None,
757
+ ref_image,
758
+ batch_size,
759
+ height,
760
+ width,
761
+ weight_dtype,
762
+ device,
763
+ generator,
764
+ do_classifier_free_guidance
765
+ )[1]
766
+ ref_image_latentes = ref_image_latentes[:, :, 0]
767
+ else:
768
+ ref_image_latentes = torch.zeros_like(latents)[:, :, 0]
769
+ else:
770
+ if ref_image is not None:
771
+ raise ValueError("The add_ref_conv is False, but ref_image is not None")
772
+ else:
773
+ ref_image_latentes = None
774
+
775
+ if comfyui_progressbar:
776
+ pbar.update(1)
777
+
778
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
779
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
780
+
781
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
782
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
783
+ # 7. Denoising loop
784
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
785
+ self.transformer.num_inference_steps = num_inference_steps
786
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
787
+ for i, t in enumerate(timesteps):
788
+ self.transformer.current_steps = i
789
+
790
+ if self.interrupt:
791
+ continue
792
+
793
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
794
+ if hasattr(self.scheduler, "scale_model_input"):
795
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
796
+
797
+ # Prepare mask latent variables
798
+ if control_camera_video is not None:
799
+ control_latents_input = None
800
+ control_camera_latents_input = (
801
+ torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents
802
+ ).to(device, weight_dtype)
803
+ else:
804
+ control_latents_input = (
805
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
806
+ ).to(device, weight_dtype)
807
+ control_camera_latents_input = None
808
+
809
+ if init_video is not None:
810
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
811
+ masked_video_latents_input = (
812
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
813
+ )
814
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
815
+ control_latents_input = y if control_latents_input is None else \
816
+ torch.cat([control_latents_input, y], dim = 1)
817
+ else:
818
+ start_image_latentes_conv_in_input = (
819
+ torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in
820
+ ).to(device, weight_dtype)
821
+ control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \
822
+ torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1)
823
+
824
+ if ref_image_latentes is not None:
825
+ full_ref = (
826
+ torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
827
+ ).to(device, weight_dtype)
828
+ else:
829
+ full_ref = None
830
+
831
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
832
+ if self.vae.spatial_compression_ratio >= 16 and init_video is not None:
833
+ temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
834
+ temp_ts = torch.cat([
835
+ temp_ts,
836
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
837
+ ])
838
+ temp_ts = temp_ts.unsqueeze(0)
839
+ timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
840
+ else:
841
+ timestep = t.expand(latent_model_input.shape[0])
842
+
843
+ if self.transformer_2 is not None:
844
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
845
+ local_transformer = self.transformer_2
846
+ else:
847
+ local_transformer = self.transformer
848
+ else:
849
+ local_transformer = self.transformer
850
+
851
+ # predict noise model_output
852
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
853
+ noise_pred = local_transformer(
854
+ x=latent_model_input,
855
+ context=in_prompt_embeds,
856
+ t=timestep,
857
+ seq_len=seq_len,
858
+ y=control_latents_input,
859
+ y_camera=control_camera_latents_input,
860
+ full_ref=full_ref,
861
+ )
862
+
863
+ # perform guidance
864
+ if do_classifier_free_guidance:
865
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
866
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
867
+
868
+ # compute the previous noisy sample x_t -> x_t-1
869
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
870
+
871
+ if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any():
872
+ latents = (1 - mask) * masked_video_latents + mask * latents
873
+
874
+ if callback_on_step_end is not None:
875
+ callback_kwargs = {}
876
+ for k in callback_on_step_end_tensor_inputs:
877
+ callback_kwargs[k] = locals()[k]
878
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
879
+
880
+ latents = callback_outputs.pop("latents", latents)
881
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
882
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
883
+
884
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
885
+ progress_bar.update()
886
+ if comfyui_progressbar:
887
+ pbar.update(1)
888
+
889
+ if output_type == "numpy":
890
+ video = self.decode_latents(latents)
891
+ elif not output_type == "latent":
892
+ video = self.decode_latents(latents)
893
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
894
+ else:
895
+ video = latents
896
+
897
+ # Offload all models
898
+ self.maybe_free_model_hooks()
899
+
900
+ if not return_dict:
901
+ video = torch.from_numpy(video)
902
+
903
+ return WanPipelineOutput(videos=video)
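A minimal end-to-end sketch of calling `Wan2_2FunControlPipeline`, assuming the Wan2.2 components (tokenizer, text encoder, VAE, transformer(s), scheduler) have already been instantiated by the repo's loading utilities; the component variables and the random control video below are placeholders, not part of this file:

```python
import torch

pipe = Wan2_2FunControlPipeline(
    tokenizer=tokenizer,            # placeholder: loaded elsewhere in the repo
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer,
    transformer_2=transformer_2,    # optional second expert used for high-noise steps
    scheduler=scheduler,
).to("cuda")

control_video = torch.rand(1, 3, 81, 480, 832)   # (b, c, f, h, w) in [0, 1], e.g. a pose video

output = pipe(
    prompt="a dancer spinning in a neon-lit alley",
    control_video=control_video,
    num_frames=81,
    height=480,
    width=832,
    num_inference_steps=50,
    guidance_scale=6.0,
)
video = output.videos   # decoded video tensor when output_type == "numpy"
```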
videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py ADDED
@@ -0,0 +1,752 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.video_processor import VideoProcessor
18
+ from einops import rearrange
19
+ from PIL import Image
20
+ from transformers import T5Tokenizer
21
+
22
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
23
+ WanT5EncoderModel, Wan2_2Transformer3DModel)
24
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas)
26
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```python
34
+ pass
35
+ ```
36
+ """
37
+
38
+
39
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
40
+ def retrieve_timesteps(
41
+ scheduler,
42
+ num_inference_steps: Optional[int] = None,
43
+ device: Optional[Union[str, torch.device]] = None,
44
+ timesteps: Optional[List[int]] = None,
45
+ sigmas: Optional[List[float]] = None,
46
+ **kwargs,
47
+ ):
48
+ """
49
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
50
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
51
+
52
+ Args:
53
+ scheduler (`SchedulerMixin`):
54
+ The scheduler to get timesteps from.
55
+ num_inference_steps (`int`):
56
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
57
+ must be `None`.
58
+ device (`str` or `torch.device`, *optional*):
59
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
60
+ timesteps (`List[int]`, *optional*):
61
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
62
+ `num_inference_steps` and `sigmas` must be `None`.
63
+ sigmas (`List[float]`, *optional*):
64
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
65
+ `num_inference_steps` and `timesteps` must be `None`.
66
+
67
+ Returns:
68
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
69
+ second element is the number of inference steps.
70
+ """
71
+ if timesteps is not None and sigmas is not None:
72
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
73
+ if timesteps is not None:
74
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accepts_timesteps:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" timestep schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ elif sigmas is not None:
84
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
85
+ if not accept_sigmas:
86
+ raise ValueError(
87
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
88
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
89
+ )
90
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ num_inference_steps = len(timesteps)
93
+ else:
94
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
95
+ timesteps = scheduler.timesteps
96
+ return timesteps, num_inference_steps
97
+
98
+
99
+ def resize_mask(mask, latent, process_first_frame_only=True):
100
+ latent_size = latent.size()
101
+ batch_size, channels, num_frames, height, width = mask.shape
102
+
103
+ if process_first_frame_only:
104
+ target_size = list(latent_size[2:])
105
+ target_size[0] = 1
106
+ first_frame_resized = F.interpolate(
107
+ mask[:, :, 0:1, :, :],
108
+ size=target_size,
109
+ mode='trilinear',
110
+ align_corners=False
111
+ )
112
+
113
+ target_size = list(latent_size[2:])
114
+ target_size[0] = target_size[0] - 1
115
+ if target_size[0] != 0:
116
+ remaining_frames_resized = F.interpolate(
117
+ mask[:, :, 1:, :, :],
118
+ size=target_size,
119
+ mode='trilinear',
120
+ align_corners=False
121
+ )
122
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
123
+ else:
124
+ resized_mask = first_frame_resized
125
+ else:
126
+ target_size = list(latent_size[2:])
127
+ resized_mask = F.interpolate(
128
+ mask,
129
+ size=target_size,
130
+ mode='trilinear',
131
+ align_corners=False
132
+ )
133
+ return resized_mask
134
+
135
+
136
+ @dataclass
137
+ class WanPipelineOutput(BaseOutput):
138
+ r"""
139
+ Output class for Wan pipelines.
140
+
141
+ Args:
142
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
143
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
144
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
145
+ `(batch_size, num_frames, channels, height, width)`.
146
+ """
147
+
148
+ videos: torch.Tensor
149
+
150
+
151
+ class Wan2_2FunInpaintPipeline(DiffusionPipeline):
152
+ r"""
153
+ Pipeline for video inpainting and image-to-video generation using Wan2.2-Fun.
154
+
155
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
156
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
157
+ """
158
+
159
+ _optional_components = ["transformer_2"]
160
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
161
+
162
+ _callback_tensor_inputs = [
163
+ "latents",
164
+ "prompt_embeds",
165
+ "negative_prompt_embeds",
166
+ ]
167
+
168
+ def __init__(
169
+ self,
170
+ tokenizer: AutoTokenizer,
171
+ text_encoder: WanT5EncoderModel,
172
+ vae: AutoencoderKLWan,
173
+ transformer: Wan2_2Transformer3DModel,
174
+ transformer_2: Wan2_2Transformer3DModel = None,
175
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.register_modules(
180
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
181
+ transformer_2=transformer_2, scheduler=scheduler
182
+ )
183
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
184
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.mask_processor = VaeImageProcessor(
186
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
187
+ )
188
+
189
+ def _get_t5_prompt_embeds(
190
+ self,
191
+ prompt: Union[str, List[str]] = None,
192
+ num_videos_per_prompt: int = 1,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ device = device or self._execution_device
198
+ dtype = dtype or self.text_encoder.dtype
199
+
200
+ prompt = [prompt] if isinstance(prompt, str) else prompt
201
+ batch_size = len(prompt)
202
+
203
+ text_inputs = self.tokenizer(
204
+ prompt,
205
+ padding="max_length",
206
+ max_length=max_sequence_length,
207
+ truncation=True,
208
+ add_special_tokens=True,
209
+ return_tensors="pt",
210
+ )
211
+ text_input_ids = text_inputs.input_ids
212
+ prompt_attention_mask = text_inputs.attention_mask
213
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
214
+
215
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
216
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
217
+ logger.warning(
218
+ "The following part of your input was truncated because `max_sequence_length` is set to "
219
+ f" {max_sequence_length} tokens: {removed_text}"
220
+ )
221
+
222
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
223
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
224
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
225
+
226
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
227
+ _, seq_len, _ = prompt_embeds.shape
228
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
229
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
232
+
233
+ def encode_prompt(
234
+ self,
235
+ prompt: Union[str, List[str]],
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ do_classifier_free_guidance: bool = True,
238
+ num_videos_per_prompt: int = 1,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ max_sequence_length: int = 512,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ):
245
+ r"""
246
+ Encodes the prompt into text encoder hidden states.
247
+
248
+ Args:
249
+ prompt (`str` or `List[str]`, *optional*):
250
+ prompt to be encoded
251
+ negative_prompt (`str` or `List[str]`, *optional*):
252
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
253
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
254
+ less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ Whether to use classifier free guidance or not.
257
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
258
+ Number of videos that should be generated per prompt.
259
+ prompt_embeds (`torch.Tensor`, *optional*):
260
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
261
+ provided, text embeddings will be generated from `prompt` input argument.
262
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
263
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
264
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
265
+ argument.
266
+ device: (`torch.device`, *optional*):
267
+ torch device
268
+ dtype: (`torch.dtype`, *optional*):
269
+ torch dtype
270
+ """
271
+ device = device or self._execution_device
272
+
273
+ prompt = [prompt] if isinstance(prompt, str) else prompt
274
+ if prompt is not None:
275
+ batch_size = len(prompt)
276
+ else:
277
+ batch_size = prompt_embeds.shape[0]
278
+
279
+ if prompt_embeds is None:
280
+ prompt_embeds = self._get_t5_prompt_embeds(
281
+ prompt=prompt,
282
+ num_videos_per_prompt=num_videos_per_prompt,
283
+ max_sequence_length=max_sequence_length,
284
+ device=device,
285
+ dtype=dtype,
286
+ )
287
+
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
289
+ negative_prompt = negative_prompt or ""
290
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
291
+
292
+ if prompt is not None and type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif batch_size != len(negative_prompt):
298
+ raise ValueError(
299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
301
+ " the batch size of `prompt`."
302
+ )
303
+
304
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
305
+ prompt=negative_prompt,
306
+ num_videos_per_prompt=num_videos_per_prompt,
307
+ max_sequence_length=max_sequence_length,
308
+ device=device,
309
+ dtype=dtype,
310
+ )
311
+
312
+ return prompt_embeds, negative_prompt_embeds
313
+
314
+ def prepare_latents(
315
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
316
+ ):
317
+ if isinstance(generator, list) and len(generator) != batch_size:
318
+ raise ValueError(
319
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
320
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
321
+ )
322
+
323
+ shape = (
324
+ batch_size,
325
+ num_channels_latents,
326
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
327
+ height // self.vae.spatial_compression_ratio,
328
+ width // self.vae.spatial_compression_ratio,
329
+ )
330
+
331
+ if latents is None:
332
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
333
+ else:
334
+ latents = latents.to(device)
335
+
336
+ # scale the initial noise by the standard deviation required by the scheduler
337
+ if hasattr(self.scheduler, "init_noise_sigma"):
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
340
+
341
+
342
+ def prepare_mask_latents(
343
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
344
+ ):
345
+ # resize the mask to latents shape as we concatenate the mask to the latents
346
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
347
+ # and half precision
348
+
349
+ if mask is not None:
350
+ mask = mask.to(device=device, dtype=self.vae.dtype)
351
+ bs = 1
352
+ new_mask = []
353
+ for i in range(0, mask.shape[0], bs):
354
+ mask_bs = mask[i : i + bs]
355
+ mask_bs = self.vae.encode(mask_bs)[0]
356
+ mask_bs = mask_bs.mode()
357
+ new_mask.append(mask_bs)
358
+ mask = torch.cat(new_mask, dim = 0)
359
+ # mask = mask * self.vae.config.scaling_factor
360
+
361
+ if masked_image is not None:
362
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
363
+ bs = 1
364
+ new_mask_pixel_values = []
365
+ for i in range(0, masked_image.shape[0], bs):
366
+ mask_pixel_values_bs = masked_image[i : i + bs]
367
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
368
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
369
+ new_mask_pixel_values.append(mask_pixel_values_bs)
370
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
371
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
372
+ else:
373
+ masked_image_latents = None
374
+
375
+ return mask, masked_image_latents
376
+
377
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
378
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
379
+ frames = (frames / 2 + 0.5).clamp(0, 1)
380
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
381
+ frames = frames.cpu().float().numpy()
382
+ return frames
383
+
384
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
385
+ def prepare_extra_step_kwargs(self, generator, eta):
386
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
387
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
388
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
389
+ # and should be between [0, 1]
390
+
391
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
392
+ extra_step_kwargs = {}
393
+ if accepts_eta:
394
+ extra_step_kwargs["eta"] = eta
395
+
396
+ # check if the scheduler accepts generator
397
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
398
+ if accepts_generator:
399
+ extra_step_kwargs["generator"] = generator
400
+ return extra_step_kwargs
401
+
402
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
403
+ def check_inputs(
404
+ self,
405
+ prompt,
406
+ height,
407
+ width,
408
+ negative_prompt,
409
+ callback_on_step_end_tensor_inputs,
410
+ prompt_embeds=None,
411
+ negative_prompt_embeds=None,
412
+ ):
413
+ if height % 8 != 0 or width % 8 != 0:
414
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
415
+
416
+ if callback_on_step_end_tensor_inputs is not None and not all(
417
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
418
+ ):
419
+ raise ValueError(
420
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
421
+ )
422
+ if prompt is not None and prompt_embeds is not None:
423
+ raise ValueError(
424
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
425
+ " only forward one of the two."
426
+ )
427
+ elif prompt is None and prompt_embeds is None:
428
+ raise ValueError(
429
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
430
+ )
431
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
432
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
433
+
434
+ if prompt is not None and negative_prompt_embeds is not None:
435
+ raise ValueError(
436
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
437
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
438
+ )
439
+
440
+ if negative_prompt is not None and negative_prompt_embeds is not None:
441
+ raise ValueError(
442
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
443
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
444
+ )
445
+
446
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
447
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
448
+ raise ValueError(
449
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
450
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
451
+ f" {negative_prompt_embeds.shape}."
452
+ )
453
+
454
+ @property
455
+ def guidance_scale(self):
456
+ return self._guidance_scale
457
+
458
+ @property
459
+ def num_timesteps(self):
460
+ return self._num_timesteps
461
+
462
+ @property
463
+ def attention_kwargs(self):
464
+ return self._attention_kwargs
465
+
466
+ @property
467
+ def interrupt(self):
468
+ return self._interrupt
469
+
470
+ @torch.no_grad()
471
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
472
+ def __call__(
473
+ self,
474
+ prompt: Optional[Union[str, List[str]]] = None,
475
+ negative_prompt: Optional[Union[str, List[str]]] = None,
476
+ height: int = 480,
477
+ width: int = 720,
478
+ video: Union[torch.FloatTensor] = None,
479
+ mask_video: Union[torch.FloatTensor] = None,
480
+ num_frames: int = 49,
481
+ num_inference_steps: int = 50,
482
+ timesteps: Optional[List[int]] = None,
483
+ guidance_scale: float = 6,
484
+ num_videos_per_prompt: int = 1,
485
+ eta: float = 0.0,
486
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
487
+ latents: Optional[torch.FloatTensor] = None,
488
+ prompt_embeds: Optional[torch.FloatTensor] = None,
489
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
490
+ output_type: str = "numpy",
491
+ return_dict: bool = False,
492
+ callback_on_step_end: Optional[
493
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
494
+ ] = None,
495
+ attention_kwargs: Optional[Dict[str, Any]] = None,
496
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
497
+ max_sequence_length: int = 512,
498
+ boundary: float = 0.875,
499
+ comfyui_progressbar: bool = False,
500
+ shift: int = 5,
501
+ ) -> Union[WanPipelineOutput, Tuple]:
502
+ """
503
+ Function invoked when calling the pipeline for generation.
504
+ Args:
505
+
506
+ Examples:
507
+
508
+ Returns:
509
+
510
+ """
511
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
512
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
513
+ num_videos_per_prompt = 1
514
+
515
+ # 1. Check inputs. Raise error if not correct
516
+ self.check_inputs(
517
+ prompt,
518
+ height,
519
+ width,
520
+ negative_prompt,
521
+ callback_on_step_end_tensor_inputs,
522
+ prompt_embeds,
523
+ negative_prompt_embeds,
524
+ )
525
+ self._guidance_scale = guidance_scale
526
+ self._attention_kwargs = attention_kwargs
527
+ self._interrupt = False
528
+
529
+ # 2. Default call parameters
530
+ if prompt is not None and isinstance(prompt, str):
531
+ batch_size = 1
532
+ elif prompt is not None and isinstance(prompt, list):
533
+ batch_size = len(prompt)
534
+ else:
535
+ batch_size = prompt_embeds.shape[0]
536
+
537
+ device = self._execution_device
538
+ weight_dtype = self.text_encoder.dtype
539
+
540
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
541
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
542
+ # corresponds to doing no classifier free guidance.
543
+ do_classifier_free_guidance = guidance_scale > 1.0
544
+
545
+ # 3. Encode input prompt
546
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
547
+ prompt,
548
+ negative_prompt,
549
+ do_classifier_free_guidance,
550
+ num_videos_per_prompt=num_videos_per_prompt,
551
+ prompt_embeds=prompt_embeds,
552
+ negative_prompt_embeds=negative_prompt_embeds,
553
+ max_sequence_length=max_sequence_length,
554
+ device=device,
555
+ )
556
+ if do_classifier_free_guidance:
557
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
558
+ else:
559
+ in_prompt_embeds = prompt_embeds
560
+
561
+ # 4. Prepare timesteps
562
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
563
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
564
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
565
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
566
+ timesteps = self.scheduler.timesteps
567
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
568
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
569
+ timesteps, _ = retrieve_timesteps(
570
+ self.scheduler,
571
+ device=device,
572
+ sigmas=sampling_sigmas)
573
+ else:
574
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
575
+ self._num_timesteps = len(timesteps)
576
+ if comfyui_progressbar:
577
+ from comfy.utils import ProgressBar
578
+ pbar = ProgressBar(num_inference_steps + 2)
579
+
580
+ # 5. Prepare latents.
581
+ if video is not None:
582
+ video_length = video.shape[2]
583
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
584
+ init_video = init_video.to(dtype=torch.float32)
585
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
586
+ else:
587
+ init_video = None
588
+
589
+ latent_channels = self.vae.config.latent_channels
590
+ latents = self.prepare_latents(
591
+ batch_size * num_videos_per_prompt,
592
+ latent_channels,
593
+ num_frames,
594
+ height,
595
+ width,
596
+ weight_dtype,
597
+ device,
598
+ generator,
599
+ latents,
600
+ )
601
+ if comfyui_progressbar:
602
+ pbar.update(1)
603
+
604
+ # Prepare mask latent variables
605
+ if init_video is not None:
606
+ if (mask_video == 255).all():
607
+ mask_latents = torch.tile(
608
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
609
+ )
610
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
611
+ if self.vae.spatial_compression_ratio >= 16:
612
+ mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype)
613
+ else:
614
+ bs, _, video_length, height, width = video.size()
615
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
616
+ mask_condition = mask_condition.to(dtype=torch.float32)
617
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
618
+
619
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
620
+ _, masked_video_latents = self.prepare_mask_latents(
621
+ None,
622
+ masked_video,
623
+ batch_size,
624
+ height,
625
+ width,
626
+ weight_dtype,
627
+ device,
628
+ generator,
629
+ do_classifier_free_guidance,
630
+ noise_aug_strength=None,
631
+ )
632
+
633
+ mask_condition = torch.concat(
634
+ [
635
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
636
+ mask_condition[:, :, 1:]
637
+ ], dim=2
638
+ )
639
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
640
+ mask_condition = mask_condition.transpose(1, 2)
641
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
642
+
643
+ if self.vae.spatial_compression_ratio >= 16:
644
+ mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
645
+ if not mask[:, :, 0, :, :].any():
646
+ mask[:, :, 1:, :, :] = 1
647
+ latents = (1 - mask) * masked_video_latents + mask * latents
648
+
649
+ if comfyui_progressbar:
650
+ pbar.update(1)
651
+
652
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
653
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
654
+
655
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
656
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
657
+ # 7. Denoising loop
658
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
659
+ self.transformer.num_inference_steps = num_inference_steps
660
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
661
+ for i, t in enumerate(timesteps):
662
+ self.transformer.current_steps = i
663
+
664
+ if self.interrupt:
665
+ continue
666
+
667
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
668
+ if hasattr(self.scheduler, "scale_model_input"):
669
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
670
+
671
+ if init_video is not None:
672
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
673
+ masked_video_latents_input = (
674
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
675
+ )
676
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
677
+
678
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
679
+ if self.vae.spatial_compression_ratio >= 16 and init_video is not None:
680
+ temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
681
+ temp_ts = torch.cat([
682
+ temp_ts,
683
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
684
+ ])
685
+ temp_ts = temp_ts.unsqueeze(0)
686
+ timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
687
+ else:
688
+ timestep = t.expand(latent_model_input.shape[0])
689
+
690
+ if self.transformer_2 is not None:
691
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
692
+ local_transformer = self.transformer_2
693
+ else:
694
+ local_transformer = self.transformer
695
+ else:
696
+ local_transformer = self.transformer
697
+
698
+ # predict noise model_output
699
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
700
+ noise_pred = local_transformer(
701
+ x=latent_model_input,
702
+ context=in_prompt_embeds,
703
+ t=timestep,
704
+ seq_len=seq_len,
705
+ y=y,
706
+ )
707
+
708
+ # perform guidance
709
+ if do_classifier_free_guidance:
710
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
711
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
712
+ else:
713
+ sample_guide_scale = self.guidance_scale
714
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
715
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
716
+
717
+ # compute the previous noisy sample x_t -> x_t-1
718
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
719
+
720
+ if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any():
721
+ latents = (1 - mask) * masked_video_latents + mask * latents
722
+
723
+ if callback_on_step_end is not None:
724
+ callback_kwargs = {}
725
+ for k in callback_on_step_end_tensor_inputs:
726
+ callback_kwargs[k] = locals()[k]
727
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
728
+
729
+ latents = callback_outputs.pop("latents", latents)
730
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
731
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
732
+
733
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
734
+ progress_bar.update()
735
+ if comfyui_progressbar:
736
+ pbar.update(1)
737
+
738
+ if output_type == "numpy":
739
+ video = self.decode_latents(latents)
740
+ elif output_type != "latent":
741
+ video = self.decode_latents(latents)
742
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
743
+ else:
744
+ video = latents
745
+
746
+ # Offload all models
747
+ self.maybe_free_model_hooks()
748
+
749
+ if not return_dict:
750
+ video = torch.from_numpy(video)
751
+
752
+ return WanPipelineOutput(videos=video)
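The block above closes `Wan2_2FunInpaintPipeline.__call__`. As a quick orientation, the sketch below shows how such a pipeline might be invoked once its components are loaded; it is illustrative only — the `pipe` object, the prompt text, and the tensor sizes are assumptions and are not defined anywhere in this commit.

```python
# Illustrative usage sketch (not part of this commit). Assumes `pipe` is an already
# constructed Wan2_2FunInpaintPipeline (tokenizer, text_encoder, vae, transformer[,
# transformer_2], scheduler loaded elsewhere). Tensors follow the 5D
# (batch, channels, frames, height, width) convention used by this pipeline.
import torch

video = torch.zeros(1, 3, 49, 480, 720)                # placeholder source clip to inpaint
mask_video = torch.full((1, 1, 49, 480, 720), 255.0)   # 255 = regenerate, 0 = keep

out = pipe(
    prompt="a corgi running on the beach",             # placeholder prompt
    video=video,
    mask_video=mask_video,
    num_frames=49,
    height=480,
    width=720,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(42),
    output_type="numpy",
)
frames = out.videos  # values in [0, 1], shape (batch, channels, frames, height, width)
```

An all-255 `mask_video`, as in this sketch, takes the fully-masked branch of the pipeline and regenerates every frame; a partial mask keeps the black regions of the input video and only re-denoises the white ones.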
videox_fun/pipeline/pipeline_wan2_2_s2v.py ADDED
@@ -0,0 +1,815 @@
1
+ import inspect
2
+ import math
3
+ import copy
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.video_processor import VideoProcessor
18
+ from einops import rearrange
19
+ from PIL import Image
20
+ from torchvision import transforms
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
24
+ Wan2_2Transformer3DModel_S2V, WanAudioEncoder,
25
+ WanT5EncoderModel)
26
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas)
28
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ EXAMPLE_DOC_STRING = """
34
+ Examples:
35
+ ```python
36
+ pass
37
+ ```
38
+ """
39
+
40
+
41
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
42
+ def retrieve_timesteps(
43
+ scheduler,
44
+ num_inference_steps: Optional[int] = None,
45
+ device: Optional[Union[str, torch.device]] = None,
46
+ timesteps: Optional[List[int]] = None,
47
+ sigmas: Optional[List[float]] = None,
48
+ **kwargs,
49
+ ):
50
+ """
51
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
52
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
53
+
54
+ Args:
55
+ scheduler (`SchedulerMixin`):
56
+ The scheduler to get timesteps from.
57
+ num_inference_steps (`int`):
58
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
59
+ must be `None`.
60
+ device (`str` or `torch.device`, *optional*):
61
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
62
+ timesteps (`List[int]`, *optional*):
63
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
64
+ `num_inference_steps` and `sigmas` must be `None`.
65
+ sigmas (`List[float]`, *optional*):
66
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
67
+ `num_inference_steps` and `timesteps` must be `None`.
68
+
69
+ Returns:
70
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
71
+ second element is the number of inference steps.
72
+ """
73
+ if timesteps is not None and sigmas is not None:
74
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
75
+ if timesteps is not None:
76
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
77
+ if not accepts_timesteps:
78
+ raise ValueError(
79
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
80
+ f" timestep schedules. Please check whether you are using the correct scheduler."
81
+ )
82
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
83
+ timesteps = scheduler.timesteps
84
+ num_inference_steps = len(timesteps)
85
+ elif sigmas is not None:
86
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
87
+ if not accept_sigmas:
88
+ raise ValueError(
89
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
90
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
91
+ )
92
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ num_inference_steps = len(timesteps)
95
+ else:
96
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
97
+ timesteps = scheduler.timesteps
98
+ return timesteps, num_inference_steps
99
+
100
+
101
+ def resize_mask(mask, latent, process_first_frame_only=True):
102
+ latent_size = latent.size()
103
+ batch_size, channels, num_frames, height, width = mask.shape
104
+
105
+ if process_first_frame_only:
106
+ target_size = list(latent_size[2:])
107
+ target_size[0] = 1
108
+ first_frame_resized = F.interpolate(
109
+ mask[:, :, 0:1, :, :],
110
+ size=target_size,
111
+ mode='trilinear',
112
+ align_corners=False
113
+ )
114
+
115
+ target_size = list(latent_size[2:])
116
+ target_size[0] = target_size[0] - 1
117
+ if target_size[0] != 0:
118
+ remaining_frames_resized = F.interpolate(
119
+ mask[:, :, 1:, :, :],
120
+ size=target_size,
121
+ mode='trilinear',
122
+ align_corners=False
123
+ )
124
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
125
+ else:
126
+ resized_mask = first_frame_resized
127
+ else:
128
+ target_size = list(latent_size[2:])
129
+ resized_mask = F.interpolate(
130
+ mask,
131
+ size=target_size,
132
+ mode='trilinear',
133
+ align_corners=False
134
+ )
135
+ return resized_mask
136
+
137
+
138
+ @dataclass
139
+ class WanPipelineOutput(BaseOutput):
140
+ r"""
141
+ Output class for Wan video pipelines.
142
+
143
+ Args:
144
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
145
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
146
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
147
+ `(batch_size, num_frames, channels, height, width)`.
148
+ """
149
+
150
+ videos: torch.Tensor
151
+
152
+
153
+ class Wan2_2S2VPipeline(DiffusionPipeline):
154
+ r"""
155
+ Pipeline for audio-driven (speech-to-video) generation using Wan2.2.
156
+
157
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
158
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
159
+ """
160
+
161
+ _optional_components = ["transformer_2", "audio_encoder"]
162
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
163
+
164
+ _callback_tensor_inputs = [
165
+ "latents",
166
+ "prompt_embeds",
167
+ "negative_prompt_embeds",
168
+ ]
169
+
170
+ def __init__(
171
+ self,
172
+ tokenizer: AutoTokenizer,
173
+ text_encoder: WanT5EncoderModel,
174
+ audio_encoder: WanAudioEncoder,
175
+ vae: AutoencoderKLWan,
176
+ transformer: Wan2_2Transformer3DModel_S2V,
177
+ transformer_2: Wan2_2Transformer3DModel_S2V = None,
178
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
179
+ ):
180
+ super().__init__()
181
+
182
+ self.register_modules(
183
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
184
+ transformer_2=transformer_2, scheduler=scheduler, audio_encoder=audio_encoder
185
+ )
186
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
187
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
188
+ self.mask_processor = VaeImageProcessor(
189
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
190
+ )
191
+ self.motion_frames = 73
192
+ self.audio_sample_m = 0
193
+ self.drop_first_motion = True
194
+
195
+ def _get_t5_prompt_embeds(
196
+ self,
197
+ prompt: Union[str, List[str]] = None,
198
+ num_videos_per_prompt: int = 1,
199
+ max_sequence_length: int = 512,
200
+ device: Optional[torch.device] = None,
201
+ dtype: Optional[torch.dtype] = None,
202
+ ):
203
+ device = device or self._execution_device
204
+ dtype = dtype or self.text_encoder.dtype
205
+
206
+ prompt = [prompt] if isinstance(prompt, str) else prompt
207
+ batch_size = len(prompt)
208
+
209
+ text_inputs = self.tokenizer(
210
+ prompt,
211
+ padding="max_length",
212
+ max_length=max_sequence_length,
213
+ truncation=True,
214
+ add_special_tokens=True,
215
+ return_tensors="pt",
216
+ )
217
+ text_input_ids = text_inputs.input_ids
218
+ prompt_attention_mask = text_inputs.attention_mask
219
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
220
+
221
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
222
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
223
+ logger.warning(
224
+ "The following part of your input was truncated because `max_sequence_length` is set to "
225
+ f" {max_sequence_length} tokens: {removed_text}"
226
+ )
227
+
228
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
229
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
230
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
231
+
232
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
233
+ _, seq_len, _ = prompt_embeds.shape
234
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
235
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
236
+
237
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
238
+
239
+ def encode_prompt(
240
+ self,
241
+ prompt: Union[str, List[str]],
242
+ negative_prompt: Optional[Union[str, List[str]]] = None,
243
+ do_classifier_free_guidance: bool = True,
244
+ num_videos_per_prompt: int = 1,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
247
+ max_sequence_length: int = 512,
248
+ device: Optional[torch.device] = None,
249
+ dtype: Optional[torch.dtype] = None,
250
+ ):
251
+ r"""
252
+ Encodes the prompt into text encoder hidden states.
253
+
254
+ Args:
255
+ prompt (`str` or `List[str]`, *optional*):
256
+ prompt to be encoded
257
+ negative_prompt (`str` or `List[str]`, *optional*):
258
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
259
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
260
+ less than `1`).
261
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
262
+ Whether to use classifier free guidance or not.
263
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
264
+ Number of videos that should be generated per prompt.
265
+ prompt_embeds (`torch.Tensor`, *optional*):
266
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
267
+ provided, text embeddings will be generated from `prompt` input argument.
268
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
269
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
270
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
271
+ argument.
272
+ device: (`torch.device`, *optional*):
273
+ torch device
274
+ dtype: (`torch.dtype`, *optional*):
275
+ torch dtype
276
+ """
277
+ device = device or self._execution_device
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ if prompt is not None:
281
+ batch_size = len(prompt)
282
+ else:
283
+ batch_size = prompt_embeds.shape[0]
284
+
285
+ if prompt_embeds is None:
286
+ prompt_embeds = self._get_t5_prompt_embeds(
287
+ prompt=prompt,
288
+ num_videos_per_prompt=num_videos_per_prompt,
289
+ max_sequence_length=max_sequence_length,
290
+ device=device,
291
+ dtype=dtype,
292
+ )
293
+
294
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
295
+ negative_prompt = negative_prompt or ""
296
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
297
+
298
+ if prompt is not None and type(prompt) is not type(negative_prompt):
299
+ raise TypeError(
300
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
301
+ f" {type(prompt)}."
302
+ )
303
+ elif batch_size != len(negative_prompt):
304
+ raise ValueError(
305
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
+ " the batch size of `prompt`."
308
+ )
309
+
310
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
311
+ prompt=negative_prompt,
312
+ num_videos_per_prompt=num_videos_per_prompt,
313
+ max_sequence_length=max_sequence_length,
314
+ device=device,
315
+ dtype=dtype,
316
+ )
317
+
318
+ return prompt_embeds, negative_prompt_embeds
319
+
320
+ def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device):
321
+ z = self.audio_encoder.extract_audio_feat(
322
+ audio_path, return_all_layers=True)
323
+ audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps(
324
+ z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m)
325
+ audio_embed_bucket = audio_embed_bucket.to(device,
326
+ weight_dtype)
327
+ audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
328
+ if len(audio_embed_bucket.shape) == 3:
329
+ audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
330
+ elif len(audio_embed_bucket.shape) == 4:
331
+ audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
332
+ return audio_embed_bucket, num_repeat
333
+
334
+ def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, weight_dtype, device):
335
+ height, width = size
336
+ if pose_video is not None:
337
+ padding_frame_num = num_repeat * num_frames - pose_video.shape[2]
338
+ pose_video = torch.cat(
339
+ [
340
+ pose_video,
341
+ -torch.ones([1, 3, padding_frame_num, height, width])
342
+ ],
343
+ dim=2
344
+ )
345
+
346
+ cond_tensors = torch.chunk(pose_video, num_repeat, dim=2)
347
+ else:
348
+ cond_tensors = [-torch.ones([1, 3, num_frames, height, width])]
349
+
350
+ pose_latents = []
351
+ for r in range(len(cond_tensors)):
352
+ cond = cond_tensors[r]
353
+ cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond],
354
+ dim=2)
355
+ cond_lat = self.vae.encode(cond.to(dtype=weight_dtype, device=device))[0].mode()[:, :, 1:]
356
+ pose_latents.append(cond_lat)
357
+ return pose_latents
358
+
359
+ def prepare_latents(
360
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
361
+ ):
362
+ if isinstance(generator, list) and len(generator) != batch_size:
363
+ raise ValueError(
364
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
365
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
366
+ )
367
+
368
+ shape = (
369
+ batch_size,
370
+ num_channels_latents,
371
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
372
+ height // self.vae.spatial_compression_ratio,
373
+ width // self.vae.spatial_compression_ratio,
374
+ )
375
+
376
+ if latents is None:
377
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
378
+ else:
379
+ latents = latents.to(device)
380
+
381
+ # scale the initial noise by the standard deviation required by the scheduler
382
+ if hasattr(self.scheduler, "init_noise_sigma"):
383
+ latents = latents * self.scheduler.init_noise_sigma
384
+ return latents
385
+
386
+ def prepare_control_latents(
387
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
388
+ ):
389
+ # resize the control to latents shape as we concatenate the control to the latents
390
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
391
+ # and half precision
392
+
393
+ if control is not None:
394
+ control = control.to(device=device, dtype=dtype)
395
+ bs = 1
396
+ new_control = []
397
+ for i in range(0, control.shape[0], bs):
398
+ control_bs = control[i : i + bs]
399
+ control_bs = self.vae.encode(control_bs)[0]
400
+ control_bs = control_bs.mode()
401
+ new_control.append(control_bs)
402
+ control = torch.cat(new_control, dim = 0)
403
+
404
+ if control_image is not None:
405
+ control_image = control_image.to(device=device, dtype=dtype)
406
+ bs = 1
407
+ new_control_pixel_values = []
408
+ for i in range(0, control_image.shape[0], bs):
409
+ control_pixel_values_bs = control_image[i : i + bs]
410
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
411
+ control_pixel_values_bs = control_pixel_values_bs.mode()
412
+ new_control_pixel_values.append(control_pixel_values_bs)
413
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
414
+ else:
415
+ control_image_latents = None
416
+
417
+ return control, control_image_latents
418
+
419
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
420
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
421
+ frames = (frames / 2 + 0.5).clamp(0, 1)
422
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
423
+ # frames = frames.cpu().float().numpy()
424
+ return frames
425
+
426
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
427
+ def prepare_extra_step_kwargs(self, generator, eta):
428
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
429
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
430
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
431
+ # and should be between [0, 1]
432
+
433
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
434
+ extra_step_kwargs = {}
435
+ if accepts_eta:
436
+ extra_step_kwargs["eta"] = eta
437
+
438
+ # check if the scheduler accepts generator
439
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
440
+ if accepts_generator:
441
+ extra_step_kwargs["generator"] = generator
442
+ return extra_step_kwargs
443
+
444
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
445
+ def check_inputs(
446
+ self,
447
+ prompt,
448
+ height,
449
+ width,
450
+ negative_prompt,
451
+ callback_on_step_end_tensor_inputs,
452
+ prompt_embeds=None,
453
+ negative_prompt_embeds=None,
454
+ ):
455
+ if height % 8 != 0 or width % 8 != 0:
456
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
457
+
458
+ if callback_on_step_end_tensor_inputs is not None and not all(
459
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
460
+ ):
461
+ raise ValueError(
462
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
463
+ )
464
+ if prompt is not None and prompt_embeds is not None:
465
+ raise ValueError(
466
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
467
+ " only forward one of the two."
468
+ )
469
+ elif prompt is None and prompt_embeds is None:
470
+ raise ValueError(
471
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
472
+ )
473
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
474
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
475
+
476
+ if prompt is not None and negative_prompt_embeds is not None:
477
+ raise ValueError(
478
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
479
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
480
+ )
481
+
482
+ if negative_prompt is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+
488
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
489
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
490
+ raise ValueError(
491
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
492
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
493
+ f" {negative_prompt_embeds.shape}."
494
+ )
495
+
496
+ @property
497
+ def guidance_scale(self):
498
+ return self._guidance_scale
499
+
500
+ @property
501
+ def num_timesteps(self):
502
+ return self._num_timesteps
503
+
504
+ @property
505
+ def attention_kwargs(self):
506
+ return self._attention_kwargs
507
+
508
+ @property
509
+ def interrupt(self):
510
+ return self._interrupt
511
+
512
+ @torch.no_grad()
513
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
514
+ def __call__(
515
+ self,
516
+ prompt: Optional[Union[str, List[str]]] = None,
517
+ negative_prompt: Optional[Union[str, List[str]]] = None,
518
+ height: int = 480,
519
+ width: int = 720,
520
+ ref_image: Union[torch.FloatTensor] = None,
521
+ audio_path = None,
522
+ pose_video = None,
523
+ num_frames: int = 49,
524
+ num_inference_steps: int = 50,
525
+ timesteps: Optional[List[int]] = None,
526
+ guidance_scale: float = 6,
527
+ num_videos_per_prompt: int = 1,
528
+ eta: float = 0.0,
529
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
530
+ latents: Optional[torch.FloatTensor] = None,
531
+ prompt_embeds: Optional[torch.FloatTensor] = None,
532
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
533
+ output_type: str = "numpy",
534
+ return_dict: bool = False,
535
+ callback_on_step_end: Optional[
536
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
537
+ ] = None,
538
+ attention_kwargs: Optional[Dict[str, Any]] = None,
539
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
540
+ max_sequence_length: int = 512,
541
+ boundary: float = 0.875,
542
+ comfyui_progressbar: bool = False,
543
+ shift: int = 5,
544
+ fps: int = 16,
545
+ init_first_frame: bool = False,
546
+ ) -> Union[WanPipelineOutput, Tuple]:
547
+ """
548
+ Function invoked when calling the pipeline for generation.
549
+ Args:
550
+
551
+ Examples:
552
+
553
+ Returns:
554
+
555
+ """
556
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
557
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
558
+ num_videos_per_prompt = 1
559
+
560
+ # 1. Check inputs. Raise error if not correct
561
+ self.check_inputs(
562
+ prompt,
563
+ height,
564
+ width,
565
+ negative_prompt,
566
+ callback_on_step_end_tensor_inputs,
567
+ prompt_embeds,
568
+ negative_prompt_embeds,
569
+ )
570
+ self._guidance_scale = guidance_scale
571
+ self._attention_kwargs = attention_kwargs
572
+ self._interrupt = False
573
+
574
+ # 2. Default call parameters
575
+ if prompt is not None and isinstance(prompt, str):
576
+ batch_size = 1
577
+ elif prompt is not None and isinstance(prompt, list):
578
+ batch_size = len(prompt)
579
+ else:
580
+ batch_size = prompt_embeds.shape[0]
581
+
582
+ device = self._execution_device
583
+ weight_dtype = self.text_encoder.dtype
584
+
585
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
586
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
587
+ # corresponds to doing no classifier free guidance.
588
+ do_classifier_free_guidance = guidance_scale > 1.0
589
+
590
+ lat_motion_frames = (self.motion_frames + 3) // 4
591
+ lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames
592
+
593
+ # 3. Encode input prompt
594
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
595
+ prompt,
596
+ negative_prompt,
597
+ do_classifier_free_guidance,
598
+ num_videos_per_prompt=num_videos_per_prompt,
599
+ prompt_embeds=prompt_embeds,
600
+ negative_prompt_embeds=negative_prompt_embeds,
601
+ max_sequence_length=max_sequence_length,
602
+ device=device,
603
+ )
604
+ if do_classifier_free_guidance:
605
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
606
+ else:
607
+ in_prompt_embeds = prompt_embeds
608
+
609
+ if comfyui_progressbar:
610
+ from comfy.utils import ProgressBar
611
+ pbar = ProgressBar(num_inference_steps + 2)
612
+
613
+ # 5. Prepare latents.
614
+ latent_channels = self.vae.config.latent_channels
615
+ if comfyui_progressbar:
616
+ pbar.update(1)
617
+
618
+ video_length = ref_image.shape[2]
619
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
620
+ ref_image = ref_image.to(dtype=torch.float32)
621
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
622
+
623
+ ref_image_latentes = self.prepare_control_latents(
624
+ None,
625
+ ref_image,
626
+ batch_size,
627
+ height,
628
+ width,
629
+ weight_dtype,
630
+ device,
631
+ generator,
632
+ do_classifier_free_guidance
633
+ )[1]
634
+ ref_image_latentes = ref_image_latentes[:, :, :1]
635
+
636
+ # Extract audio emb
637
+ audio_emb, num_repeat = self.encode_audio_embeddings(
638
+ audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device
639
+ )
640
+
641
+ # Encode the motion latents
642
+ motion_latents = torch.zeros(
643
+ [1, 3, self.motion_frames, height, width],
644
+ dtype=weight_dtype,
645
+ device=device
646
+ )
647
+ videos_last_frames = motion_latents.detach()
648
+ drop_first_motion = self.drop_first_motion
649
+ if init_first_frame:
650
+ drop_first_motion = False
651
+ motion_latents[:, :, -6:] = ref_image
652
+ motion_latents = self.vae.encode(motion_latents)[0].mode()
653
+
654
+ # Get pose cond input if need
655
+ if pose_video is not None:
656
+ video_length = pose_video.shape[2]
657
+ pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width)
658
+ pose_video = pose_video.to(dtype=torch.float32)
659
+ pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length)
660
+ pose_latents = self.encode_pose_latents(
661
+ pose_video=pose_video,
662
+ num_repeat=num_repeat,
663
+ num_frames=num_frames,
664
+ size=(height, width),
665
+ fps=fps,
666
+ weight_dtype=weight_dtype,
667
+ device=device
668
+ )
669
+
670
+ if comfyui_progressbar:
671
+ pbar.update(1)
672
+
673
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
674
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
675
+
676
+ videos = []
677
+ copy_timesteps = copy.deepcopy(timesteps)
678
+ copy_latents = copy.deepcopy(latents)
679
+ for r in range(num_repeat):
680
+ # Prepare timesteps
681
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
682
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1)
683
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
684
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
685
+ timesteps = self.scheduler.timesteps
686
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
687
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
688
+ timesteps, _ = retrieve_timesteps(
689
+ self.scheduler,
690
+ device=device,
691
+ sigmas=sampling_sigmas)
692
+ else:
693
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps)
694
+ self._num_timesteps = len(timesteps)
695
+
696
+ target_shape = (self.vae.latent_channels, lat_target_frames, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
697
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
698
+
699
+ latents = self.prepare_latents(
700
+ batch_size * num_videos_per_prompt,
701
+ latent_channels,
702
+ num_frames,
703
+ height,
704
+ width,
705
+ weight_dtype,
706
+ device,
707
+ generator,
708
+ copy_latents,
709
+ num_length_latents=target_shape[1]
710
+ )
711
+ # 7. Denoising loop
712
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
713
+ self.transformer.num_inference_steps = num_inference_steps
714
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
715
+ for i, t in enumerate(timesteps):
716
+ self.transformer.current_steps = i
717
+
718
+ if self.interrupt:
719
+ continue
720
+
721
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
722
+ if hasattr(self.scheduler, "scale_model_input"):
723
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
724
+
725
+ with torch.no_grad():
726
+ left_idx = r * num_frames
727
+ right_idx = r * num_frames + num_frames
728
+ cond_latents = pose_latents[r] if pose_video is not None else pose_latents[0] * 0
729
+ cond_latents = cond_latents.to(dtype=weight_dtype, device=device)
730
+ audio_input = audio_emb[..., left_idx:right_idx]
731
+
732
+ pose_latents_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
733
+ motion_latents_input = torch.cat([motion_latents] * 2) if do_classifier_free_guidance else motion_latents
734
+ audio_emb_input = torch.cat([audio_input * 0] + [audio_input]) if do_classifier_free_guidance else audio_input
735
+ ref_image_latentes_input = torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
736
+ motion_frames=[[self.motion_frames, (self.motion_frames + 3) // 4]] * 2 if do_classifier_free_guidance else [[self.motion_frames, (self.motion_frames + 3) // 4]]
737
+ timestep = t.expand(latent_model_input.shape[0])
738
+
739
+ if self.transformer_2 is not None:
740
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
741
+ local_transformer = self.transformer_2
742
+ else:
743
+ local_transformer = self.transformer
744
+ else:
745
+ local_transformer = self.transformer
746
+
747
+ # predict noise model_output
748
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
749
+ noise_pred = local_transformer(
750
+ x=latent_model_input,
751
+ context=in_prompt_embeds,
752
+ t=timestep,
753
+ seq_len=seq_len,
754
+ cond_states=pose_latents_input,
755
+ motion_latents=motion_latents_input,
756
+ ref_latents=ref_image_latentes_input,
757
+ audio_input=audio_emb_input,
758
+ motion_frames=motion_frames,
759
+ drop_motion_frames=drop_first_motion and r == 0,
760
+ )
761
+ # perform guidance
762
+ if do_classifier_free_guidance:
763
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
764
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
765
+ else:
766
+ sample_guide_scale = self.guidance_scale
767
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
768
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
769
+
770
+ # compute the previous noisy sample x_t -> x_t-1
771
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
772
+
773
+ if callback_on_step_end is not None:
774
+ callback_kwargs = {}
775
+ for k in callback_on_step_end_tensor_inputs:
776
+ callback_kwargs[k] = locals()[k]
777
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
778
+
779
+ latents = callback_outputs.pop("latents", latents)
780
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
781
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
782
+
783
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
784
+ progress_bar.update()
785
+ if comfyui_progressbar:
786
+ pbar.update(1)
787
+
788
+ if not (drop_first_motion and r == 0):
789
+ decode_latents = torch.cat([motion_latents, latents], dim=2)
790
+ else:
791
+ decode_latents = torch.cat([ref_image_latentes, latents], dim=2)
792
+
793
+ image = self.vae.decode(decode_latents).sample
794
+ image = image[:, :, -(num_frames):]
795
+ if (drop_first_motion and r == 0):
796
+ image = image[:, :, 3:]
797
+
798
+ overlap_frames_num = min(self.motion_frames, image.shape[2])
799
+ videos_last_frames = torch.cat(
800
+ [
801
+ videos_last_frames[:, :, overlap_frames_num:],
802
+ image[:, :, -overlap_frames_num:]
803
+ ],
804
+ dim=2
805
+ ).to(dtype=motion_latents.dtype, device=motion_latents.device)
806
+ motion_latents = self.vae.encode(videos_last_frames)[0].mode()
807
+ videos.append(image)
808
+
809
+ videos = torch.cat(videos, dim=2)
810
+ videos = (videos / 2 + 0.5).clamp(0, 1)
811
+
812
+ # Offload all models
813
+ self.maybe_free_model_hooks()
814
+
815
+ return WanPipelineOutput(videos=videos.float().cpu())
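That completes `Wan2_2S2VPipeline`, which splits the extracted audio embedding into `num_frames`-sized segments, denoises each segment in turn, and re-encodes the last `motion_frames` of every decoded segment as motion context for the next. A hedged usage sketch follows; `pipe`, the audio file name, and the prompt are placeholders, not part of this commit.

```python
# Illustrative usage sketch (not part of this commit). Assumes `pipe` is an already
# constructed Wan2_2S2VPipeline (including its WanAudioEncoder) and that the caller
# supplies the reference frame and an audio file such as "speech.wav".
import torch

ref_image = torch.zeros(1, 3, 1, 480, 720)  # single reference frame, (b, c, f, h, w)

out = pipe(
    prompt="a person talking to the camera",  # placeholder prompt
    ref_image=ref_image,
    audio_path="speech.wav",                  # driving audio, chunked internally
    num_frames=49,
    fps=16,
    height=480,
    width=720,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(42),
)
video = out.videos  # concatenated segments, values in [0, 1]
```

Setting `init_first_frame=True` additionally seeds the motion buffer with the reference frame instead of dropping the first motion segment, which is the behaviour controlled by `drop_first_motion` in the code above.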
videox_fun/pipeline/pipeline_wan2_2_ti2v.py ADDED
@@ -0,0 +1,732 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.video_processor import VideoProcessor
18
+ from einops import rearrange
19
+ from PIL import Image
20
+ from transformers import T5Tokenizer
21
+
22
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
23
+ WanT5EncoderModel, Wan2_2Transformer3DModel)
24
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas)
26
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```python
34
+ pass
35
+ ```
36
+ """
37
+
38
+
39
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
40
+ def retrieve_timesteps(
41
+ scheduler,
42
+ num_inference_steps: Optional[int] = None,
43
+ device: Optional[Union[str, torch.device]] = None,
44
+ timesteps: Optional[List[int]] = None,
45
+ sigmas: Optional[List[float]] = None,
46
+ **kwargs,
47
+ ):
48
+ """
49
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
50
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
51
+
52
+ Args:
53
+ scheduler (`SchedulerMixin`):
54
+ The scheduler to get timesteps from.
55
+ num_inference_steps (`int`):
56
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
57
+ must be `None`.
58
+ device (`str` or `torch.device`, *optional*):
59
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
60
+ timesteps (`List[int]`, *optional*):
61
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
62
+ `num_inference_steps` and `sigmas` must be `None`.
63
+ sigmas (`List[float]`, *optional*):
64
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
65
+ `num_inference_steps` and `timesteps` must be `None`.
66
+
67
+ Returns:
68
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
69
+ second element is the number of inference steps.
70
+ """
71
+ if timesteps is not None and sigmas is not None:
72
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
73
+ if timesteps is not None:
74
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accepts_timesteps:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" timestep schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ elif sigmas is not None:
84
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
85
+ if not accept_sigmas:
86
+ raise ValueError(
87
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
88
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
89
+ )
90
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ num_inference_steps = len(timesteps)
93
+ else:
94
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
95
+ timesteps = scheduler.timesteps
96
+ return timesteps, num_inference_steps
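A minimal usage sketch of `retrieve_timesteps`, assuming a diffusers `FlowMatchEulerDiscreteScheduler` whose `set_timesteps` accepts `sigmas` (the helper itself checks the signature and raises otherwise); the sigma values below are illustrative only.

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

# default spacing: let the scheduler derive the schedule from the step count
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# custom spacing: pass sigmas instead (then `num_inference_steps`/`timesteps` stay None)
sigmas = np.linspace(1.0, 0.0, 31)[:-1].tolist()
timesteps, num_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
```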
97
+
98
+
99
+ def resize_mask(mask, latent, process_first_frame_only=True):
100
+ latent_size = latent.size()
101
+ batch_size, channels, num_frames, height, width = mask.shape
102
+
103
+ if process_first_frame_only:
104
+ target_size = list(latent_size[2:])
105
+ target_size[0] = 1
106
+ first_frame_resized = F.interpolate(
107
+ mask[:, :, 0:1, :, :],
108
+ size=target_size,
109
+ mode='trilinear',
110
+ align_corners=False
111
+ )
112
+
113
+ target_size = list(latent_size[2:])
114
+ target_size[0] = target_size[0] - 1
115
+ if target_size[0] != 0:
116
+ remaining_frames_resized = F.interpolate(
117
+ mask[:, :, 1:, :, :],
118
+ size=target_size,
119
+ mode='trilinear',
120
+ align_corners=False
121
+ )
122
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
123
+ else:
124
+ resized_mask = first_frame_resized
125
+ else:
126
+ target_size = list(latent_size[2:])
127
+ resized_mask = F.interpolate(
128
+ mask,
129
+ size=target_size,
130
+ mode='trilinear',
131
+ align_corners=False
132
+ )
133
+ return resized_mask
134
+
135
+
136
+ @dataclass
137
+ class WanPipelineOutput(BaseOutput):
138
+ r"""
139
+ Output class for Wan pipelines.
140
+
141
+ Args:
142
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
143
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
144
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
145
+ `(batch_size, num_frames, channels, height, width)`.
146
+ """
147
+
148
+ videos: torch.Tensor
149
+
150
+
151
+ class Wan2_2TI2VPipeline(DiffusionPipeline):
152
+ r"""
153
+ Pipeline for text-to-video generation using Wan.
154
+
155
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
156
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
157
+ """
158
+
159
+ _optional_components = ["transformer_2"]
160
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
161
+
162
+ _callback_tensor_inputs = [
163
+ "latents",
164
+ "prompt_embeds",
165
+ "negative_prompt_embeds",
166
+ ]
167
+
168
+ def __init__(
169
+ self,
170
+ tokenizer: AutoTokenizer,
171
+ text_encoder: WanT5EncoderModel,
172
+ vae: AutoencoderKLWan,
173
+ transformer: Wan2_2Transformer3DModel,
174
+ transformer_2: Wan2_2Transformer3DModel = None,
175
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.register_modules(
180
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
181
+ transformer_2=transformer_2, scheduler=scheduler
182
+ )
183
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
184
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.mask_processor = VaeImageProcessor(
186
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
187
+ )
188
+
189
+ def _get_t5_prompt_embeds(
190
+ self,
191
+ prompt: Union[str, List[str]] = None,
192
+ num_videos_per_prompt: int = 1,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ device = device or self._execution_device
198
+ dtype = dtype or self.text_encoder.dtype
199
+
200
+ prompt = [prompt] if isinstance(prompt, str) else prompt
201
+ batch_size = len(prompt)
202
+
203
+ text_inputs = self.tokenizer(
204
+ prompt,
205
+ padding="max_length",
206
+ max_length=max_sequence_length,
207
+ truncation=True,
208
+ add_special_tokens=True,
209
+ return_tensors="pt",
210
+ )
211
+ text_input_ids = text_inputs.input_ids
212
+ prompt_attention_mask = text_inputs.attention_mask
213
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
214
+
215
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
216
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
217
+ logger.warning(
218
+ "The following part of your input was truncated because `max_sequence_length` is set to "
219
+ f" {max_sequence_length} tokens: {removed_text}"
220
+ )
221
+
222
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
223
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
224
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
225
+
226
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
227
+ _, seq_len, _ = prompt_embeds.shape
228
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
229
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
232
+
233
+ def encode_prompt(
234
+ self,
235
+ prompt: Union[str, List[str]],
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ do_classifier_free_guidance: bool = True,
238
+ num_videos_per_prompt: int = 1,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ max_sequence_length: int = 512,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ):
245
+ r"""
246
+ Encodes the prompt into text encoder hidden states.
247
+
248
+ Args:
249
+ prompt (`str` or `List[str]`, *optional*):
250
+ prompt to be encoded
251
+ negative_prompt (`str` or `List[str]`, *optional*):
252
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
253
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
254
+ less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ Whether to use classifier free guidance or not.
257
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
258
+ Number of videos that should be generated per prompt.
259
+ prompt_embeds (`torch.Tensor`, *optional*):
260
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
261
+ provided, text embeddings will be generated from `prompt` input argument.
262
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
263
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
264
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
265
+ argument.
266
+ device: (`torch.device`, *optional*):
267
+ torch device
268
+ dtype: (`torch.dtype`, *optional*):
269
+ torch dtype
270
+ """
271
+ device = device or self._execution_device
272
+
273
+ prompt = [prompt] if isinstance(prompt, str) else prompt
274
+ if prompt is not None:
275
+ batch_size = len(prompt)
276
+ else:
277
+ batch_size = prompt_embeds.shape[0]
278
+
279
+ if prompt_embeds is None:
280
+ prompt_embeds = self._get_t5_prompt_embeds(
281
+ prompt=prompt,
282
+ num_videos_per_prompt=num_videos_per_prompt,
283
+ max_sequence_length=max_sequence_length,
284
+ device=device,
285
+ dtype=dtype,
286
+ )
287
+
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
289
+ negative_prompt = negative_prompt or ""
290
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
291
+
292
+ if prompt is not None and type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif batch_size != len(negative_prompt):
298
+ raise ValueError(
299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
301
+ " the batch size of `prompt`."
302
+ )
303
+
304
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
305
+ prompt=negative_prompt,
306
+ num_videos_per_prompt=num_videos_per_prompt,
307
+ max_sequence_length=max_sequence_length,
308
+ device=device,
309
+ dtype=dtype,
310
+ )
311
+
312
+ return prompt_embeds, negative_prompt_embeds
313
+
314
+ def prepare_latents(
315
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
316
+ ):
317
+ if isinstance(generator, list) and len(generator) != batch_size:
318
+ raise ValueError(
319
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
320
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
321
+ )
322
+
323
+ shape = (
324
+ batch_size,
325
+ num_channels_latents,
326
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
327
+ height // self.vae.spatial_compression_ratio,
328
+ width // self.vae.spatial_compression_ratio,
329
+ )
330
+
331
+ if latents is None:
332
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
333
+ else:
334
+ latents = latents.to(device)
335
+
336
+ # scale the initial noise by the standard deviation required by the scheduler
337
+ if hasattr(self.scheduler, "init_noise_sigma"):
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
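A worked example of the latent grid computed by `prepare_latents` above; the compression ratios are illustrative assumptions (in the pipeline they come from the VAE config).

```python
# illustrative values, not read from any checkpoint
temporal_compression_ratio = 4
spatial_compression_ratio = 8

batch_size, num_channels_latents = 1, 16
num_frames, height, width = 49, 480, 720

shape = (
    batch_size,
    num_channels_latents,
    (num_frames - 1) // temporal_compression_ratio + 1,  # 13 latent frames
    height // spatial_compression_ratio,                 # 60
    width // spatial_compression_ratio,                  # 90
)
print(shape)  # (1, 16, 13, 60, 90)
```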
340
+
341
+ def prepare_mask_latents(
342
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
343
+ ):
344
+ # resize the mask to latents shape as we concatenate the mask to the latents
345
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
346
+ # and half precision
347
+
348
+ if mask is not None:
349
+ mask = mask.to(device=device, dtype=self.vae.dtype)
350
+ bs = 1
351
+ new_mask = []
352
+ for i in range(0, mask.shape[0], bs):
353
+ mask_bs = mask[i : i + bs]
354
+ mask_bs = self.vae.encode(mask_bs)[0]
355
+ mask_bs = mask_bs.mode()
356
+ new_mask.append(mask_bs)
357
+ mask = torch.cat(new_mask, dim = 0)
358
+ # mask = mask * self.vae.config.scaling_factor
359
+
360
+ if masked_image is not None:
361
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
362
+ bs = 1
363
+ new_mask_pixel_values = []
364
+ for i in range(0, masked_image.shape[0], bs):
365
+ mask_pixel_values_bs = masked_image[i : i + bs]
366
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
367
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
368
+ new_mask_pixel_values.append(mask_pixel_values_bs)
369
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
370
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
371
+ else:
372
+ masked_image_latents = None
373
+
374
+ return mask, masked_image_latents
375
+
376
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
377
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
378
+ frames = (frames / 2 + 0.5).clamp(0, 1)
379
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
380
+ frames = frames.cpu().float().numpy()
381
+ return frames
382
+
383
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
384
+ def prepare_extra_step_kwargs(self, generator, eta):
385
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
386
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
387
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
388
+ # and should be between [0, 1]
389
+
390
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
391
+ extra_step_kwargs = {}
392
+ if accepts_eta:
393
+ extra_step_kwargs["eta"] = eta
394
+
395
+ # check if the scheduler accepts generator
396
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
397
+ if accepts_generator:
398
+ extra_step_kwargs["generator"] = generator
399
+ return extra_step_kwargs
400
+
401
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
402
+ def check_inputs(
403
+ self,
404
+ prompt,
405
+ height,
406
+ width,
407
+ negative_prompt,
408
+ callback_on_step_end_tensor_inputs,
409
+ prompt_embeds=None,
410
+ negative_prompt_embeds=None,
411
+ ):
412
+ if height % 8 != 0 or width % 8 != 0:
413
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
414
+
415
+ if callback_on_step_end_tensor_inputs is not None and not all(
416
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
417
+ ):
418
+ raise ValueError(
419
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
420
+ )
421
+ if prompt is not None and prompt_embeds is not None:
422
+ raise ValueError(
423
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
424
+ " only forward one of the two."
425
+ )
426
+ elif prompt is None and prompt_embeds is None:
427
+ raise ValueError(
428
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
429
+ )
430
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
431
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
432
+
433
+ if prompt is not None and negative_prompt_embeds is not None:
434
+ raise ValueError(
435
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
436
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
437
+ )
438
+
439
+ if negative_prompt is not None and negative_prompt_embeds is not None:
440
+ raise ValueError(
441
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
442
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
443
+ )
444
+
445
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
446
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
447
+ raise ValueError(
448
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
449
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
450
+ f" {negative_prompt_embeds.shape}."
451
+ )
452
+
453
+ @property
454
+ def guidance_scale(self):
455
+ return self._guidance_scale
456
+
457
+ @property
458
+ def num_timesteps(self):
459
+ return self._num_timesteps
460
+
461
+ @property
462
+ def attention_kwargs(self):
463
+ return self._attention_kwargs
464
+
465
+ @property
466
+ def interrupt(self):
467
+ return self._interrupt
468
+
469
+ @torch.no_grad()
470
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
471
+ def __call__(
472
+ self,
473
+ prompt: Optional[Union[str, List[str]]] = None,
474
+ negative_prompt: Optional[Union[str, List[str]]] = None,
475
+ height: int = 480,
476
+ width: int = 720,
477
+ video: Union[torch.FloatTensor] = None,
478
+ mask_video: Union[torch.FloatTensor] = None,
479
+ num_frames: int = 49,
480
+ num_inference_steps: int = 50,
481
+ timesteps: Optional[List[int]] = None,
482
+ guidance_scale: float = 6,
483
+ num_videos_per_prompt: int = 1,
484
+ eta: float = 0.0,
485
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
486
+ latents: Optional[torch.FloatTensor] = None,
487
+ prompt_embeds: Optional[torch.FloatTensor] = None,
488
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
489
+ output_type: str = "numpy",
490
+ return_dict: bool = False,
491
+ callback_on_step_end: Optional[
492
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
493
+ ] = None,
494
+ attention_kwargs: Optional[Dict[str, Any]] = None,
495
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
496
+ max_sequence_length: int = 512,
497
+ boundary: float = 0.875,
498
+ comfyui_progressbar: bool = False,
499
+ shift: int = 5,
500
+ ) -> Union[WanPipelineOutput, Tuple]:
501
+ """
502
+ Function invoked when calling the pipeline for generation.
503
+ Args:
504
+
505
+ Examples:
506
+
507
+ Returns:
508
+
509
+ """
510
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
511
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
512
+ num_videos_per_prompt = 1
513
+
514
+ # 1. Check inputs. Raise error if not correct
515
+ self.check_inputs(
516
+ prompt,
517
+ height,
518
+ width,
519
+ negative_prompt,
520
+ callback_on_step_end_tensor_inputs,
521
+ prompt_embeds,
522
+ negative_prompt_embeds,
523
+ )
524
+ self._guidance_scale = guidance_scale
525
+ self._attention_kwargs = attention_kwargs
526
+ self._interrupt = False
527
+
528
+ # 2. Default call parameters
529
+ if prompt is not None and isinstance(prompt, str):
530
+ batch_size = 1
531
+ elif prompt is not None and isinstance(prompt, list):
532
+ batch_size = len(prompt)
533
+ else:
534
+ batch_size = prompt_embeds.shape[0]
535
+
536
+ device = self._execution_device
537
+ weight_dtype = self.text_encoder.dtype
538
+
539
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
540
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
541
+ # corresponds to doing no classifier free guidance.
542
+ do_classifier_free_guidance = guidance_scale > 1.0
543
+
544
+ # 3. Encode input prompt
545
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
546
+ prompt,
547
+ negative_prompt,
548
+ do_classifier_free_guidance,
549
+ num_videos_per_prompt=num_videos_per_prompt,
550
+ prompt_embeds=prompt_embeds,
551
+ negative_prompt_embeds=negative_prompt_embeds,
552
+ max_sequence_length=max_sequence_length,
553
+ device=device,
554
+ )
555
+ if do_classifier_free_guidance:
556
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
557
+ else:
558
+ in_prompt_embeds = prompt_embeds
559
+
560
+ # 4. Prepare timesteps
561
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
562
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
563
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
564
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
565
+ timesteps = self.scheduler.timesteps
566
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
567
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
568
+ timesteps, _ = retrieve_timesteps(
569
+ self.scheduler,
570
+ device=device,
571
+ sigmas=sampling_sigmas)
572
+ else:
573
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
574
+ self._num_timesteps = len(timesteps)
575
+ if comfyui_progressbar:
576
+ from comfy.utils import ProgressBar
577
+ pbar = ProgressBar(num_inference_steps + 2)
578
+
579
+ # 5. Prepare latents.
580
+ if video is not None:
581
+ video_length = video.shape[2]
582
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
583
+ init_video = init_video.to(dtype=torch.float32)
584
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
585
+ else:
586
+ init_video = None
587
+
588
+ latent_channels = self.vae.config.latent_channels
589
+ latents = self.prepare_latents(
590
+ batch_size * num_videos_per_prompt,
591
+ latent_channels,
592
+ num_frames,
593
+ height,
594
+ width,
595
+ weight_dtype,
596
+ device,
597
+ generator,
598
+ latents,
599
+ )
600
+ if comfyui_progressbar:
601
+ pbar.update(1)
602
+
603
+ # Prepare mask latent variables
604
+ if init_video is not None and not (mask_video == 255).all():
605
+ bs, _, video_length, height, width = video.size()
606
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
607
+ mask_condition = mask_condition.to(dtype=torch.float32)
608
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
609
+
610
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
611
+ _, masked_video_latents = self.prepare_mask_latents(
612
+ None,
613
+ masked_video,
614
+ batch_size,
615
+ height,
616
+ width,
617
+ weight_dtype,
618
+ device,
619
+ generator,
620
+ do_classifier_free_guidance,
621
+ noise_aug_strength=None,
622
+ )
623
+
624
+ mask_condition = torch.concat(
625
+ [
626
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
627
+ mask_condition[:, :, 1:]
628
+ ], dim=2
629
+ )
630
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
631
+ mask_condition = mask_condition.transpose(1, 2)
632
+
633
+ mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
634
+ latents = (1 - mask) * masked_video_latents + mask * latents
635
+ else:
636
+ init_video = None
637
+
638
+ if comfyui_progressbar:
639
+ pbar.update(1)
640
+
641
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
642
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
643
+
644
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
645
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
646
+ # 7. Denoising loop
647
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
648
+ self.transformer.num_inference_steps = num_inference_steps
649
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
650
+ for i, t in enumerate(timesteps):
651
+ self.transformer.current_steps = i
652
+
653
+ if self.interrupt:
654
+ continue
655
+
656
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
657
+ if hasattr(self.scheduler, "scale_model_input"):
658
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
659
+
660
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
661
+ if init_video is not None:
662
+ temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
663
+ temp_ts = torch.cat([
664
+ temp_ts,
665
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
666
+ ])
667
+ temp_ts = temp_ts.unsqueeze(0)
668
+ timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
669
+ else:
670
+ timestep = t.expand(latent_model_input.shape[0])
671
+
672
+ if self.transformer_2 is not None:
673
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
674
+ local_transformer = self.transformer_2
675
+ else:
676
+ local_transformer = self.transformer
677
+ else:
678
+ local_transformer = self.transformer
679
+
680
+ # predict noise model_output
681
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
682
+ noise_pred = local_transformer(
683
+ x=latent_model_input,
684
+ context=in_prompt_embeds,
685
+ t=timestep,
686
+ seq_len=seq_len,
687
+ )
688
+
689
+ # perform guidance
690
+ if do_classifier_free_guidance:
691
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
692
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
693
+ else:
694
+ sample_guide_scale = self.guidance_scale
695
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
696
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
697
+
698
+ # compute the previous noisy sample x_t -> x_t-1
699
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
700
+ if init_video is not None:
701
+ latents = (1 - mask) * masked_video_latents + mask * latents
702
+
703
+ if callback_on_step_end is not None:
704
+ callback_kwargs = {}
705
+ for k in callback_on_step_end_tensor_inputs:
706
+ callback_kwargs[k] = locals()[k]
707
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
708
+
709
+ latents = callback_outputs.pop("latents", latents)
710
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
711
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
712
+
713
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
714
+ progress_bar.update()
715
+ if comfyui_progressbar:
716
+ pbar.update(1)
717
+
718
+ if output_type == "numpy":
719
+ video = self.decode_latents(latents)
720
+ elif not output_type == "latent":
721
+ video = self.decode_latents(latents)
722
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
723
+ else:
724
+ video = latents
725
+
726
+ # Offload all models
727
+ self.maybe_free_model_hooks()
728
+
729
+ if not return_dict:
730
+ video = torch.from_numpy(video)
731
+
732
+ return WanPipelineOutput(videos=video)
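A minimal invocation sketch for `Wan2_2TI2VPipeline`; component loading is elided, so `tokenizer`, `text_encoder`, `vae`, `transformer` and `scheduler` are assumed to be already-constructed objects of the types listed in `__init__`, and the prompt and settings are illustrative.

```python
import torch

# `tokenizer`, `text_encoder`, `vae`, `transformer`, `scheduler`: assumed to exist already
pipe = Wan2_2TI2VPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer,
    scheduler=scheduler,
).to("cuda")

output = pipe(
    prompt="a red fox running through snow",
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
videos = output.videos  # decoded frames scaled to [0, 1]
```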
videox_fun/pipeline/pipeline_wan2_2_vace_fun.py ADDED
@@ -0,0 +1,801 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
24
+ WanT5EncoderModel, VaceWanTransformer3DModel)
25
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
26
+ get_sampling_sigmas)
27
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```python
35
+ pass
36
+ ```
37
+ """
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
74
+ if timesteps is not None:
75
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
76
+ if not accepts_timesteps:
77
+ raise ValueError(
78
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
79
+ f" timestep schedules. Please check whether you are using the correct scheduler."
80
+ )
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
98
+
99
+
100
+ def resize_mask(mask, latent, process_first_frame_only=True):
101
+ latent_size = latent.size()
102
+ batch_size, channels, num_frames, height, width = mask.shape
103
+
104
+ if process_first_frame_only:
105
+ target_size = list(latent_size[2:])
106
+ target_size[0] = 1
107
+ first_frame_resized = F.interpolate(
108
+ mask[:, :, 0:1, :, :],
109
+ size=target_size,
110
+ mode='trilinear',
111
+ align_corners=False
112
+ )
113
+
114
+ target_size = list(latent_size[2:])
115
+ target_size[0] = target_size[0] - 1
116
+ if target_size[0] != 0:
117
+ remaining_frames_resized = F.interpolate(
118
+ mask[:, :, 1:, :, :],
119
+ size=target_size,
120
+ mode='trilinear',
121
+ align_corners=False
122
+ )
123
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
124
+ else:
125
+ resized_mask = first_frame_resized
126
+ else:
127
+ target_size = list(latent_size[2:])
128
+ resized_mask = F.interpolate(
129
+ mask,
130
+ size=target_size,
131
+ mode='trilinear',
132
+ align_corners=False
133
+ )
134
+ return resized_mask
135
+
136
+
137
+ @dataclass
138
+ class WanPipelineOutput(BaseOutput):
139
+ r"""
140
+ Output class for Wan pipelines.
141
+
142
+ Args:
143
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
144
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
145
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
146
+ `(batch_size, num_frames, channels, height, width)`.
147
+ """
148
+
149
+ videos: torch.Tensor
150
+
151
+
152
+ class Wan2_2VaceFunPipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline for text-to-video generation using Wan.
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
157
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
158
+ """
159
+
160
+ _optional_components = ["transformer_2"]
161
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
162
+
163
+ _callback_tensor_inputs = [
164
+ "latents",
165
+ "prompt_embeds",
166
+ "negative_prompt_embeds",
167
+ ]
168
+
169
+ def __init__(
170
+ self,
171
+ tokenizer: AutoTokenizer,
172
+ text_encoder: WanT5EncoderModel,
173
+ vae: AutoencoderKLWan,
174
+ transformer: VaceWanTransformer3DModel,
175
+ transformer_2: VaceWanTransformer3DModel = None,
176
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
177
+ ):
178
+ super().__init__()
179
+
180
+ self.register_modules(
181
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
182
+ transformer_2=transformer_2, scheduler=scheduler
183
+ )
184
+
185
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
186
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
187
+ self.mask_processor = VaeImageProcessor(
188
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
189
+ )
190
+
191
+ def _get_t5_prompt_embeds(
192
+ self,
193
+ prompt: Union[str, List[str]] = None,
194
+ num_videos_per_prompt: int = 1,
195
+ max_sequence_length: int = 512,
196
+ device: Optional[torch.device] = None,
197
+ dtype: Optional[torch.dtype] = None,
198
+ ):
199
+ device = device or self._execution_device
200
+ dtype = dtype or self.text_encoder.dtype
201
+
202
+ prompt = [prompt] if isinstance(prompt, str) else prompt
203
+ batch_size = len(prompt)
204
+
205
+ text_inputs = self.tokenizer(
206
+ prompt,
207
+ padding="max_length",
208
+ max_length=max_sequence_length,
209
+ truncation=True,
210
+ add_special_tokens=True,
211
+ return_tensors="pt",
212
+ )
213
+ text_input_ids = text_inputs.input_ids
214
+ prompt_attention_mask = text_inputs.attention_mask
215
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
216
+
217
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
218
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
219
+ logger.warning(
220
+ "The following part of your input was truncated because `max_sequence_length` is set to "
221
+ f" {max_sequence_length} tokens: {removed_text}"
222
+ )
223
+
224
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
225
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
226
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
227
+
228
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
229
+ _, seq_len, _ = prompt_embeds.shape
230
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
231
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
232
+
233
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
234
+
235
+ def encode_prompt(
236
+ self,
237
+ prompt: Union[str, List[str]],
238
+ negative_prompt: Optional[Union[str, List[str]]] = None,
239
+ do_classifier_free_guidance: bool = True,
240
+ num_videos_per_prompt: int = 1,
241
+ prompt_embeds: Optional[torch.Tensor] = None,
242
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
243
+ max_sequence_length: int = 512,
244
+ device: Optional[torch.device] = None,
245
+ dtype: Optional[torch.dtype] = None,
246
+ ):
247
+ r"""
248
+ Encodes the prompt into text encoder hidden states.
249
+
250
+ Args:
251
+ prompt (`str` or `List[str]`, *optional*):
252
+ prompt to be encoded
253
+ negative_prompt (`str` or `List[str]`, *optional*):
254
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
255
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
256
+ less than `1`).
257
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
258
+ Whether to use classifier free guidance or not.
259
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
260
+ Number of videos that should be generated per prompt.
261
+ prompt_embeds (`torch.Tensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263
+ provided, text embeddings will be generated from `prompt` input argument.
264
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
265
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
266
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
267
+ argument.
268
+ device: (`torch.device`, *optional*):
269
+ torch device
270
+ dtype: (`torch.dtype`, *optional*):
271
+ torch dtype
272
+ """
273
+ device = device or self._execution_device
274
+
275
+ prompt = [prompt] if isinstance(prompt, str) else prompt
276
+ if prompt is not None:
277
+ batch_size = len(prompt)
278
+ else:
279
+ batch_size = prompt_embeds.shape[0]
280
+
281
+ if prompt_embeds is None:
282
+ prompt_embeds = self._get_t5_prompt_embeds(
283
+ prompt=prompt,
284
+ num_videos_per_prompt=num_videos_per_prompt,
285
+ max_sequence_length=max_sequence_length,
286
+ device=device,
287
+ dtype=dtype,
288
+ )
289
+
290
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
291
+ negative_prompt = negative_prompt or ""
292
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
293
+
294
+ if prompt is not None and type(prompt) is not type(negative_prompt):
295
+ raise TypeError(
296
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
297
+ f" {type(prompt)}."
298
+ )
299
+ elif batch_size != len(negative_prompt):
300
+ raise ValueError(
301
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
302
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
303
+ " the batch size of `prompt`."
304
+ )
305
+
306
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
307
+ prompt=negative_prompt,
308
+ num_videos_per_prompt=num_videos_per_prompt,
309
+ max_sequence_length=max_sequence_length,
310
+ device=device,
311
+ dtype=dtype,
312
+ )
313
+
314
+ return prompt_embeds, negative_prompt_embeds
315
+
316
+ def prepare_latents(
317
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
318
+ ):
319
+ if isinstance(generator, list) and len(generator) != batch_size:
320
+ raise ValueError(
321
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
322
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
323
+ )
324
+
325
+ shape = (
326
+ batch_size,
327
+ num_channels_latents,
328
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
329
+ height // self.vae.spatial_compression_ratio,
330
+ width // self.vae.spatial_compression_ratio,
331
+ )
332
+
333
+ if latents is None:
334
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
335
+ else:
336
+ latents = latents.to(device)
337
+
338
+ # scale the initial noise by the standard deviation required by the scheduler
339
+ if hasattr(self.scheduler, "init_noise_sigma"):
340
+ latents = latents * self.scheduler.init_noise_sigma
341
+ return latents
342
+
343
+ def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
344
+ vae = self.vae if vae is None else vae
345
+ weight_dtype = frames.dtype
346
+ if ref_images is None:
347
+ ref_images = [None] * len(frames)
348
+ else:
349
+ assert len(frames) == len(ref_images)
350
+
351
+ if masks is None:
352
+ latents = vae.encode(frames)[0].mode()
353
+ else:
354
+ masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks]
355
+ inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
356
+ reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
357
+ inactive = vae.encode(inactive)[0].mode()
358
+ reactive = vae.encode(reactive)[0].mode()
359
+ latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
360
+
361
+ cat_latents = []
362
+ for latent, refs in zip(latents, ref_images):
363
+ if refs is not None:
364
+ if masks is None:
365
+ ref_latent = vae.encode(refs)[0].mode()
366
+ else:
367
+ ref_latent = vae.encode(refs)[0].mode()
368
+ ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
369
+ assert all([x.shape[1] == 1 for x in ref_latent])
370
+ latent = torch.cat([*ref_latent, latent], dim=1)
371
+ cat_latents.append(latent)
372
+ return cat_latents
373
+
374
+ def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]):
375
+ if ref_images is None:
376
+ ref_images = [None] * len(masks)
377
+ else:
378
+ assert len(masks) == len(ref_images)
379
+
380
+ result_masks = []
381
+ for mask, refs in zip(masks, ref_images):
382
+ c, depth, height, width = mask.shape
383
+ new_depth = int((depth + 3) // vae_stride[0])
384
+ height = 2 * (int(height) // (vae_stride[1] * 2))
385
+ width = 2 * (int(width) // (vae_stride[2] * 2))
386
+
387
+ # reshape
388
+ mask = mask[0, :, :, :]
389
+ mask = mask.view(
390
+ depth, height, vae_stride[1], width, vae_stride[1]
391
+ ) # depth, height, 8, width, 8
392
+ mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
393
+ mask = mask.reshape(
394
+ vae_stride[1] * vae_stride[2], depth, height, width
395
+ ) # 8*8, depth, height, width
396
+
397
+ # interpolation
398
+ mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
399
+
400
+ if refs is not None:
401
+ length = len(refs)
402
+ mask_pad = torch.zeros_like(mask[:, :length, :, :])
403
+ mask = torch.cat((mask_pad, mask), dim=1)
404
+ result_masks.append(mask)
405
+ return result_masks
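A shape sketch of the mask folding performed in `vace_encode_masks` above, using the default `vae_stride` of `[4, 8, 8]`; the frame count and resolution are illustrative.

```python
import torch
import torch.nn.functional as F

depth, height, width = 49, 480, 720
vae_stride = [4, 8, 8]

mask = torch.rand(1, depth, height, width)      # single-channel mask, as used above
new_depth = int((depth + 3) // vae_stride[0])   # 13
lat_h = 2 * (height // (vae_stride[1] * 2))     # 60
lat_w = 2 * (width // (vae_stride[2] * 2))      # 90

m = mask[0]                                              # (depth, H, W)
m = m.view(depth, lat_h, vae_stride[1], lat_w, vae_stride[1])
m = m.permute(2, 4, 0, 1, 3)                             # (8, 8, depth, lat_h, lat_w)
m = m.reshape(vae_stride[1] * vae_stride[2], depth, lat_h, lat_w)
m = F.interpolate(m.unsqueeze(0), size=(new_depth, lat_h, lat_w), mode="nearest-exact").squeeze(0)
print(m.shape)  # torch.Size([64, 13, 60, 90])
```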
406
+
407
+ def vace_latent(self, z, m):
408
+ return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
409
+
410
+ def prepare_control_latents(
411
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
412
+ ):
413
+ # resize the control to latents shape as we concatenate the control to the latents
414
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
415
+ # and half precision
416
+
417
+ if control is not None:
418
+ control = control.to(device=device, dtype=dtype)
419
+ bs = 1
420
+ new_control = []
421
+ for i in range(0, control.shape[0], bs):
422
+ control_bs = control[i : i + bs]
423
+ control_bs = self.vae.encode(control_bs)[0]
424
+ control_bs = control_bs.mode()
425
+ new_control.append(control_bs)
426
+ control = torch.cat(new_control, dim = 0)
427
+
428
+ if control_image is not None:
429
+ control_image = control_image.to(device=device, dtype=dtype)
430
+ bs = 1
431
+ new_control_pixel_values = []
432
+ for i in range(0, control_image.shape[0], bs):
433
+ control_pixel_values_bs = control_image[i : i + bs]
434
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
435
+ control_pixel_values_bs = control_pixel_values_bs.mode()
436
+ new_control_pixel_values.append(control_pixel_values_bs)
437
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
438
+ else:
439
+ control_image_latents = None
440
+
441
+ return control, control_image_latents
442
+
443
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
444
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
445
+ frames = (frames / 2 + 0.5).clamp(0, 1)
446
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
447
+ frames = frames.cpu().float().numpy()
448
+ return frames
449
+
450
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
451
+ def prepare_extra_step_kwargs(self, generator, eta):
452
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
453
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
454
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
455
+ # and should be between [0, 1]
456
+
457
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
458
+ extra_step_kwargs = {}
459
+ if accepts_eta:
460
+ extra_step_kwargs["eta"] = eta
461
+
462
+ # check if the scheduler accepts generator
463
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
464
+ if accepts_generator:
465
+ extra_step_kwargs["generator"] = generator
466
+ return extra_step_kwargs
467
+
468
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
469
+ def check_inputs(
470
+ self,
471
+ prompt,
472
+ height,
473
+ width,
474
+ negative_prompt,
475
+ callback_on_step_end_tensor_inputs,
476
+ prompt_embeds=None,
477
+ negative_prompt_embeds=None,
478
+ ):
479
+ if height % 8 != 0 or width % 8 != 0:
480
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
481
+
482
+ if callback_on_step_end_tensor_inputs is not None and not all(
483
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
484
+ ):
485
+ raise ValueError(
486
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
487
+ )
488
+ if prompt is not None and prompt_embeds is not None:
489
+ raise ValueError(
490
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
491
+ " only forward one of the two."
492
+ )
493
+ elif prompt is None and prompt_embeds is None:
494
+ raise ValueError(
495
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
496
+ )
497
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
498
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
499
+
500
+ if prompt is not None and negative_prompt_embeds is not None:
501
+ raise ValueError(
502
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
503
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
504
+ )
505
+
506
+ if negative_prompt is not None and negative_prompt_embeds is not None:
507
+ raise ValueError(
508
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
509
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
510
+ )
511
+
512
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
513
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
514
+ raise ValueError(
515
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
516
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
517
+ f" {negative_prompt_embeds.shape}."
518
+ )
519
+
520
+ @property
521
+ def guidance_scale(self):
522
+ return self._guidance_scale
523
+
524
+ @property
525
+ def num_timesteps(self):
526
+ return self._num_timesteps
527
+
528
+ @property
529
+ def attention_kwargs(self):
530
+ return self._attention_kwargs
531
+
532
+ @property
533
+ def interrupt(self):
534
+ return self._interrupt
535
+
536
+ @torch.no_grad()
537
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
538
+ def __call__(
539
+ self,
540
+ prompt: Optional[Union[str, List[str]]] = None,
541
+ negative_prompt: Optional[Union[str, List[str]]] = None,
542
+ height: int = 480,
543
+ width: int = 720,
544
+ video: Union[torch.FloatTensor] = None,
545
+ mask_video: Union[torch.FloatTensor] = None,
546
+ control_video: Union[torch.FloatTensor] = None,
547
+ subject_ref_images: Union[torch.FloatTensor] = None,
548
+ num_frames: int = 49,
549
+ num_inference_steps: int = 50,
550
+ timesteps: Optional[List[int]] = None,
551
+ guidance_scale: float = 6,
552
+ num_videos_per_prompt: int = 1,
553
+ eta: float = 0.0,
554
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
555
+ latents: Optional[torch.FloatTensor] = None,
556
+ prompt_embeds: Optional[torch.FloatTensor] = None,
557
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
558
+ output_type: str = "numpy",
559
+ return_dict: bool = False,
560
+ callback_on_step_end: Optional[
561
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
562
+ ] = None,
563
+ attention_kwargs: Optional[Dict[str, Any]] = None,
564
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
565
+ max_sequence_length: int = 512,
566
+ boundary: float = 0.875,
567
+ comfyui_progressbar: bool = False,
568
+ shift: int = 5,
569
+ vace_context_scale: float = 1.0,
570
+ ) -> Union[WanPipelineOutput, Tuple]:
571
+ """
572
+ Function invoked when calling the pipeline for generation.
573
+ Args:
574
+
575
+ Examples:
576
+
577
+ Returns:
578
+
579
+ """
580
+
581
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
582
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
583
+ num_videos_per_prompt = 1
584
+
585
+ # 1. Check inputs. Raise error if not correct
586
+ self.check_inputs(
587
+ prompt,
588
+ height,
589
+ width,
590
+ negative_prompt,
591
+ callback_on_step_end_tensor_inputs,
592
+ prompt_embeds,
593
+ negative_prompt_embeds,
594
+ )
595
+ self._guidance_scale = guidance_scale
596
+ self._attention_kwargs = attention_kwargs
597
+ self._interrupt = False
598
+
599
+ # 2. Default call parameters
600
+ if prompt is not None and isinstance(prompt, str):
601
+ batch_size = 1
602
+ elif prompt is not None and isinstance(prompt, list):
603
+ batch_size = len(prompt)
604
+ else:
605
+ batch_size = prompt_embeds.shape[0]
606
+
607
+ device = self._execution_device
608
+ weight_dtype = self.text_encoder.dtype
609
+
610
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
611
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
612
+ # corresponds to doing no classifier free guidance.
613
+ do_classifier_free_guidance = guidance_scale > 1.0
614
+
615
+ # 3. Encode input prompt
616
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
617
+ prompt,
618
+ negative_prompt,
619
+ do_classifier_free_guidance,
620
+ num_videos_per_prompt=num_videos_per_prompt,
621
+ prompt_embeds=prompt_embeds,
622
+ negative_prompt_embeds=negative_prompt_embeds,
623
+ max_sequence_length=max_sequence_length,
624
+ device=device,
625
+ )
626
+ if do_classifier_free_guidance:
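+ # Note: as in the sibling Wan pipelines in this repo, `encode_prompt` returns per-sample lists of
+ # embeddings rather than stacked tensors, so the "+" below concatenates the negative and positive
+ # lists (it is not a tensor addition).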
627
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
628
+ else:
629
+ in_prompt_embeds = prompt_embeds
630
+
631
+ # 4. Prepare timesteps
632
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
633
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
634
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
635
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
636
+ timesteps = self.scheduler.timesteps
637
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
638
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
639
+ timesteps, _ = retrieve_timesteps(
640
+ self.scheduler,
641
+ device=device,
642
+ sigmas=sampling_sigmas)
643
+ else:
644
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
645
+ self._num_timesteps = len(timesteps)
646
+ if comfyui_progressbar:
647
+ from comfy.utils import ProgressBar
648
+ pbar = ProgressBar(num_inference_steps + 2)
649
+
650
+ latent_channels = self.vae.config.latent_channels
651
+
652
+ if comfyui_progressbar:
653
+ pbar.update(1)
654
+
655
+ # Prepare mask latent variables
656
+ if mask_video is not None:
657
+ bs, _, video_length, height, width = video.size()
658
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
659
+ mask_condition = mask_condition.to(dtype=torch.float32)
660
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
661
+ mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device)
662
+
663
+
664
+ if control_video is not None:
665
+ video_length = control_video.shape[2]
666
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
667
+ control_video = control_video.to(dtype=torch.float32)
668
+ input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
669
+
670
+ input_video = input_video.to(dtype=weight_dtype, device=device)
671
+
672
+ elif video is not None:
673
+ video_length = video.shape[2]
674
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
675
+ init_video = init_video.to(dtype=torch.float32)
676
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device)
677
+
678
+ input_video = init_video * (mask_condition < 0.5)
679
+ input_video = input_video.to(dtype=weight_dtype, device=device)
680
+
681
+ if subject_ref_images is not None:
682
+ video_length = subject_ref_images.shape[2]
683
+ subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
684
+ subject_ref_images = subject_ref_images.to(dtype=torch.float32)
685
+ subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
686
+ subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device)
687
+
688
+ bs, c, f, h, w = subject_ref_images.size()
689
+ new_subject_ref_images = []
690
+ for i in range(bs):
691
+ new_subject_ref_images.append([])
692
+ for j in range(f):
693
+ new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1])
694
+ subject_ref_images = new_subject_ref_images
695
+
696
+ vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae)
697
+ mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images, vae_stride=[4, self.vae.spatial_compression_ratio, self.vae.spatial_compression_ratio])
698
+ vace_context = self.vace_latent(vace_latents, mask_latents)
699
+
700
+ # 5. Prepare latents.
701
+ latents = self.prepare_latents(
702
+ batch_size * num_videos_per_prompt,
703
+ latent_channels,
704
+ num_frames,
705
+ height,
706
+ width,
707
+ weight_dtype,
708
+ device,
709
+ generator,
710
+ latents,
711
+ num_length_latents=vace_latents[0].size(1)
712
+ )
713
+
714
+ if comfyui_progressbar:
715
+ pbar.update(1)
716
+
717
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
718
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
719
+
720
+ target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3))
721
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
722
+ # 7. Denoising loop
723
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
724
+ self.transformer.num_inference_steps = num_inference_steps
725
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
726
+ for i, t in enumerate(timesteps):
727
+ self.transformer.current_steps = i
728
+
729
+ if self.interrupt:
730
+ continue
731
+
732
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
733
+ if hasattr(self.scheduler, "scale_model_input"):
734
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
735
+
736
+ vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context
737
+
738
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
739
+ timestep = t.expand(latent_model_input.shape[0])
740
+
741
+ if self.transformer_2 is not None:
742
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
743
+ local_transformer = self.transformer_2
744
+ else:
745
+ local_transformer = self.transformer
746
+ else:
747
+ local_transformer = self.transformer
748
+
749
+ # predict noise model_output
750
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
751
+ noise_pred = local_transformer(
752
+ x=latent_model_input,
753
+ context=in_prompt_embeds,
754
+ t=timestep,
755
+ vace_context=vace_context_input,
756
+ seq_len=seq_len,
757
+ vace_context_scale=vace_context_scale,
758
+ )
759
+
760
+ # perform guidance
761
+ if do_classifier_free_guidance:
762
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
763
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
764
+
765
+ # compute the previous noisy sample x_t -> x_t-1
766
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
767
+
768
+ if callback_on_step_end is not None:
769
+ callback_kwargs = {}
770
+ for k in callback_on_step_end_tensor_inputs:
771
+ callback_kwargs[k] = locals()[k]
772
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
773
+
774
+ latents = callback_outputs.pop("latents", latents)
775
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
776
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
777
+
778
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
779
+ progress_bar.update()
780
+ if comfyui_progressbar:
781
+ pbar.update(1)
782
+
783
+ if subject_ref_images is not None:
784
+ len_subject_ref_images = len(subject_ref_images[0])
785
+ latents = latents[:, :, len_subject_ref_images:, :, :]
786
+
787
+ if output_type == "numpy":
788
+ video = self.decode_latents(latents)
789
+ elif output_type != "latent":
790
+ video = self.decode_latents(latents)
791
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
792
+ else:
793
+ video = latents
794
+
795
+ # Offload all models
796
+ self.maybe_free_model_hooks()
797
+
798
+ if not return_dict:
799
+ video = torch.from_numpy(video)
800
+
801
+ return WanPipelineOutput(videos=video)
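A minimal usage sketch for the VACE-style pipeline above. The loading step is omitted, and the `pipeline` object, prompts, and tensor shapes are illustrative assumptions rather than part of this commit:

```python
# Hypothetical usage sketch only; how the pipeline is assembled from the
# tokenizer/text encoder/VAE/transformer(s)/scheduler is not shown in this file.
import torch

def run_vace_example(pipeline, device="cuda"):
    num_frames, height, width = 49, 480, 720
    # Inputs follow the (b, c, f, h, w) layout expected by __call__.
    video = torch.zeros(1, 3, num_frames, height, width)
    mask_video = torch.ones(1, 1, num_frames, height, width)  # 1 = region to regenerate

    output = pipeline(
        prompt="a corgi running on the beach",
        negative_prompt="blurry, low quality",
        video=video,
        mask_video=mask_video,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=torch.Generator(device=device).manual_seed(42),
        output_type="numpy",
    )
    # With the default return_dict=False path above, `videos` is a torch tensor
    # converted from the decoded numpy frames.
    return output.videos
```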
videox_fun/pipeline/pipeline_wan_fun_control.py ADDED
@@ -0,0 +1,799 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
24
+ WanT5EncoderModel, WanTransformer3DModel)
25
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
26
+ get_sampling_sigmas)
27
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```python
35
+ pass
36
+ ```
37
+ """
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
74
+ if timesteps is not None:
75
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
76
+ if not accepts_timesteps:
77
+ raise ValueError(
78
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
79
+ f" timestep schedules. Please check whether you are using the correct scheduler."
80
+ )
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
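As the docstring above notes, exactly one of `num_inference_steps`, `timesteps`, or `sigmas` should be supplied; a small illustration (the `scheduler` instance is assumed, and custom `timesteps`/`sigmas` only work if its `set_timesteps` accepts them):

```python
# scheduler is any diffusers-style scheduler instance (assumption for illustration).
steps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
steps, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0], device="cuda")
steps, n = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25, 0.0], device="cuda")
# Passing both `timesteps` and `sigmas` at once raises a ValueError.
```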
98
+
99
+
100
+ def resize_mask(mask, latent, process_first_frame_only=True):
101
+ latent_size = latent.size()
102
+ batch_size, channels, num_frames, height, width = mask.shape
103
+
104
+ if process_first_frame_only:
105
+ target_size = list(latent_size[2:])
106
+ target_size[0] = 1
107
+ first_frame_resized = F.interpolate(
108
+ mask[:, :, 0:1, :, :],
109
+ size=target_size,
110
+ mode='trilinear',
111
+ align_corners=False
112
+ )
113
+
114
+ target_size = list(latent_size[2:])
115
+ target_size[0] = target_size[0] - 1
116
+ if target_size[0] != 0:
117
+ remaining_frames_resized = F.interpolate(
118
+ mask[:, :, 1:, :, :],
119
+ size=target_size,
120
+ mode='trilinear',
121
+ align_corners=False
122
+ )
123
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
124
+ else:
125
+ resized_mask = first_frame_resized
126
+ else:
127
+ target_size = list(latent_size[2:])
128
+ resized_mask = F.interpolate(
129
+ mask,
130
+ size=target_size,
131
+ mode='trilinear',
132
+ align_corners=False
133
+ )
134
+ return resized_mask
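A quick sanity check of `resize_mask` with dummy tensors (the shapes are arbitrary illustrations, not values used by this pipeline):

```python
import torch

# Pixel-space mask (b, c, f, h, w) and a latent tensor whose temporal/spatial grid is smaller.
mask = torch.ones(1, 1, 17, 256, 256)
latent = torch.zeros(1, 16, 5, 32, 32)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 5, 32, 32]) -- the first latent frame comes only from frame 0
```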
135
+
136
+
137
+ @dataclass
138
+ class WanPipelineOutput(BaseOutput):
139
+ r"""
140
+ Output class for Wan pipelines.
141
+
142
+ Args:
143
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
144
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
145
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
146
+ `(batch_size, num_frames, channels, height, width)`.
147
+ """
148
+
149
+ videos: torch.Tensor
150
+
151
+
152
+ class WanFunControlPipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline for controllable text-to-video generation (Fun-Control) using Wan.
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
157
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
158
+ """
159
+
160
+ _optional_components = []
161
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
162
+
163
+ _callback_tensor_inputs = [
164
+ "latents",
165
+ "prompt_embeds",
166
+ "negative_prompt_embeds",
167
+ ]
168
+
169
+ def __init__(
170
+ self,
171
+ tokenizer: AutoTokenizer,
172
+ text_encoder: WanT5EncoderModel,
173
+ vae: AutoencoderKLWan,
174
+ transformer: WanTransformer3DModel,
175
+ clip_image_encoder: CLIPModel,
176
+ scheduler: FlowMatchEulerDiscreteScheduler,
177
+ ):
178
+ super().__init__()
179
+
180
+ self.register_modules(
181
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
182
+ )
183
+
184
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
186
+ self.mask_processor = VaeImageProcessor(
187
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
188
+ )
189
+
190
+ def _get_t5_prompt_embeds(
191
+ self,
192
+ prompt: Union[str, List[str]] = None,
193
+ num_videos_per_prompt: int = 1,
194
+ max_sequence_length: int = 512,
195
+ device: Optional[torch.device] = None,
196
+ dtype: Optional[torch.dtype] = None,
197
+ ):
198
+ device = device or self._execution_device
199
+ dtype = dtype or self.text_encoder.dtype
200
+
201
+ prompt = [prompt] if isinstance(prompt, str) else prompt
202
+ batch_size = len(prompt)
203
+
204
+ text_inputs = self.tokenizer(
205
+ prompt,
206
+ padding="max_length",
207
+ max_length=max_sequence_length,
208
+ truncation=True,
209
+ add_special_tokens=True,
210
+ return_tensors="pt",
211
+ )
212
+ text_input_ids = text_inputs.input_ids
213
+ prompt_attention_mask = text_inputs.attention_mask
214
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
215
+
216
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
217
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
218
+ logger.warning(
219
+ "The following part of your input was truncated because `max_sequence_length` is set to "
220
+ f" {max_sequence_length} tokens: {removed_text}"
221
+ )
222
+
223
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
224
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
225
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ _, seq_len, _ = prompt_embeds.shape
229
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
230
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
231
+
232
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
233
+
234
+ def encode_prompt(
235
+ self,
236
+ prompt: Union[str, List[str]],
237
+ negative_prompt: Optional[Union[str, List[str]]] = None,
238
+ do_classifier_free_guidance: bool = True,
239
+ num_videos_per_prompt: int = 1,
240
+ prompt_embeds: Optional[torch.Tensor] = None,
241
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
242
+ max_sequence_length: int = 512,
243
+ device: Optional[torch.device] = None,
244
+ dtype: Optional[torch.dtype] = None,
245
+ ):
246
+ r"""
247
+ Encodes the prompt into text encoder hidden states.
248
+
249
+ Args:
250
+ prompt (`str` or `List[str]`, *optional*):
251
+ prompt to be encoded
252
+ negative_prompt (`str` or `List[str]`, *optional*):
253
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
254
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
255
+ less than `1`).
256
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
257
+ Whether to use classifier free guidance or not.
258
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
259
+ Number of videos that should be generated per prompt.
260
+ prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
262
+ provided, text embeddings will be generated from `prompt` input argument.
263
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
264
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
265
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
266
+ argument.
267
+ device: (`torch.device`, *optional*):
268
+ torch device
269
+ dtype: (`torch.dtype`, *optional*):
270
+ torch dtype
271
+ """
272
+ device = device or self._execution_device
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ if prompt is not None:
276
+ batch_size = len(prompt)
277
+ else:
278
+ batch_size = prompt_embeds.shape[0]
279
+
280
+ if prompt_embeds is None:
281
+ prompt_embeds = self._get_t5_prompt_embeds(
282
+ prompt=prompt,
283
+ num_videos_per_prompt=num_videos_per_prompt,
284
+ max_sequence_length=max_sequence_length,
285
+ device=device,
286
+ dtype=dtype,
287
+ )
288
+
289
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
290
+ negative_prompt = negative_prompt or ""
291
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
292
+
293
+ if prompt is not None and type(prompt) is not type(negative_prompt):
294
+ raise TypeError(
295
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
296
+ f" {type(prompt)}."
297
+ )
298
+ elif batch_size != len(negative_prompt):
299
+ raise ValueError(
300
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
301
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
302
+ " the batch size of `prompt`."
303
+ )
304
+
305
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
306
+ prompt=negative_prompt,
307
+ num_videos_per_prompt=num_videos_per_prompt,
308
+ max_sequence_length=max_sequence_length,
309
+ device=device,
310
+ dtype=dtype,
311
+ )
312
+
313
+ return prompt_embeds, negative_prompt_embeds
314
+
315
+ def prepare_latents(
316
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
317
+ ):
318
+ if isinstance(generator, list) and len(generator) != batch_size:
319
+ raise ValueError(
320
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
321
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
322
+ )
323
+
324
+ shape = (
325
+ batch_size,
326
+ num_channels_latents,
327
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
328
+ height // self.vae.spatial_compression_ratio,
329
+ width // self.vae.spatial_compression_ratio,
330
+ )
331
+
332
+ if latents is None:
333
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
334
+ else:
335
+ latents = latents.to(device)
336
+
337
+ # scale the initial noise by the standard deviation required by the scheduler
338
+ if hasattr(self.scheduler, "init_noise_sigma"):
339
+ latents = latents * self.scheduler.init_noise_sigma
340
+ return latents
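For orientation, a worked example of the shape computed in `prepare_latents`, assuming (purely as an illustration) 16 latent channels, a 4x temporal compression, and an 8x spatial compression:

```python
# Illustrative numbers only; the real values come from self.vae's config at runtime.
batch_size, num_channels_latents = 1, 16
num_frames, height, width = 49, 480, 720
temporal_compression_ratio, spatial_compression_ratio = 4, 8

shape = (
    batch_size,
    num_channels_latents,
    (num_frames - 1) // temporal_compression_ratio + 1,  # 48 // 4 + 1 = 13 latent frames
    height // spatial_compression_ratio,                 # 480 // 8 = 60
    width // spatial_compression_ratio,                  # 720 // 8 = 90
)
print(shape)  # (1, 16, 13, 60, 90)
```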
341
+
342
+ def prepare_control_latents(
343
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
344
+ ):
345
+ # resize the control to latents shape as we concatenate the control to the latents
346
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
347
+ # and half precision
348
+
349
+ if control is not None:
350
+ control = control.to(device=device, dtype=dtype)
351
+ bs = 1
352
+ new_control = []
353
+ for i in range(0, control.shape[0], bs):
354
+ control_bs = control[i : i + bs]
355
+ control_bs = self.vae.encode(control_bs)[0]
356
+ control_bs = control_bs.mode()
357
+ new_control.append(control_bs)
358
+ control = torch.cat(new_control, dim = 0)
359
+
360
+ if control_image is not None:
361
+ control_image = control_image.to(device=device, dtype=dtype)
362
+ bs = 1
363
+ new_control_pixel_values = []
364
+ for i in range(0, control_image.shape[0], bs):
365
+ control_pixel_values_bs = control_image[i : i + bs]
366
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
367
+ control_pixel_values_bs = control_pixel_values_bs.mode()
368
+ new_control_pixel_values.append(control_pixel_values_bs)
369
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
370
+ else:
371
+ control_image_latents = None
372
+
373
+ return control, control_image_latents
374
+
375
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
376
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
377
+ frames = (frames / 2 + 0.5).clamp(0, 1)
378
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
379
+ frames = frames.cpu().float().numpy()
380
+ return frames
381
+
382
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
383
+ def prepare_extra_step_kwargs(self, generator, eta):
384
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
385
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
386
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
387
+ # and should be between [0, 1]
388
+
389
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
390
+ extra_step_kwargs = {}
391
+ if accepts_eta:
392
+ extra_step_kwargs["eta"] = eta
393
+
394
+ # check if the scheduler accepts generator
395
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
396
+ if accepts_generator:
397
+ extra_step_kwargs["generator"] = generator
398
+ return extra_step_kwargs
399
+
400
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
401
+ def check_inputs(
402
+ self,
403
+ prompt,
404
+ height,
405
+ width,
406
+ negative_prompt,
407
+ callback_on_step_end_tensor_inputs,
408
+ prompt_embeds=None,
409
+ negative_prompt_embeds=None,
410
+ ):
411
+ if height % 8 != 0 or width % 8 != 0:
412
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
413
+
414
+ if callback_on_step_end_tensor_inputs is not None and not all(
415
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
416
+ ):
417
+ raise ValueError(
418
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
419
+ )
420
+ if prompt is not None and prompt_embeds is not None:
421
+ raise ValueError(
422
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
423
+ " only forward one of the two."
424
+ )
425
+ elif prompt is None and prompt_embeds is None:
426
+ raise ValueError(
427
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
428
+ )
429
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
430
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
431
+
432
+ if prompt is not None and negative_prompt_embeds is not None:
433
+ raise ValueError(
434
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
435
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
436
+ )
437
+
438
+ if negative_prompt is not None and negative_prompt_embeds is not None:
439
+ raise ValueError(
440
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
441
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
442
+ )
443
+
444
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
445
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
446
+ raise ValueError(
447
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
448
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
449
+ f" {negative_prompt_embeds.shape}."
450
+ )
451
+
452
+ @property
453
+ def guidance_scale(self):
454
+ return self._guidance_scale
455
+
456
+ @property
457
+ def num_timesteps(self):
458
+ return self._num_timesteps
459
+
460
+ @property
461
+ def attention_kwargs(self):
462
+ return self._attention_kwargs
463
+
464
+ @property
465
+ def interrupt(self):
466
+ return self._interrupt
467
+
468
+ @torch.no_grad()
469
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
470
+ def __call__(
471
+ self,
472
+ prompt: Optional[Union[str, List[str]]] = None,
473
+ negative_prompt: Optional[Union[str, List[str]]] = None,
474
+ height: int = 480,
475
+ width: int = 720,
476
+ control_video: Optional[torch.FloatTensor] = None,
477
+ control_camera_video: Optional[torch.FloatTensor] = None,
478
+ start_image: Optional[torch.FloatTensor] = None,
479
+ ref_image: Optional[torch.FloatTensor] = None,
480
+ num_frames: int = 49,
481
+ num_inference_steps: int = 50,
482
+ timesteps: Optional[List[int]] = None,
483
+ guidance_scale: float = 6,
484
+ num_videos_per_prompt: int = 1,
485
+ eta: float = 0.0,
486
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
487
+ latents: Optional[torch.FloatTensor] = None,
488
+ prompt_embeds: Optional[torch.FloatTensor] = None,
489
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
490
+ output_type: str = "numpy",
491
+ return_dict: bool = False,
492
+ callback_on_step_end: Optional[
493
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
494
+ ] = None,
495
+ attention_kwargs: Optional[Dict[str, Any]] = None,
496
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
497
+ clip_image: Optional[Image.Image] = None,
498
+ max_sequence_length: int = 512,
499
+ comfyui_progressbar: bool = False,
500
+ shift: int = 5,
501
+ ) -> Union[WanPipelineOutput, Tuple]:
502
+ """
503
+ Function invoked when calling the pipeline for generation.
504
+ Args:
505
+
506
+ Examples:
507
+
508
+ Returns:
509
+
510
+ """
511
+
512
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
513
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
514
+ num_videos_per_prompt = 1
515
+
516
+ # 1. Check inputs. Raise error if not correct
517
+ self.check_inputs(
518
+ prompt,
519
+ height,
520
+ width,
521
+ negative_prompt,
522
+ callback_on_step_end_tensor_inputs,
523
+ prompt_embeds,
524
+ negative_prompt_embeds,
525
+ )
526
+ self._guidance_scale = guidance_scale
527
+ self._attention_kwargs = attention_kwargs
528
+ self._interrupt = False
529
+
530
+ # 2. Default call parameters
531
+ if prompt is not None and isinstance(prompt, str):
532
+ batch_size = 1
533
+ elif prompt is not None and isinstance(prompt, list):
534
+ batch_size = len(prompt)
535
+ else:
536
+ batch_size = prompt_embeds.shape[0]
537
+
538
+ device = self._execution_device
539
+ weight_dtype = self.text_encoder.dtype
540
+
541
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
542
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
543
+ # corresponds to doing no classifier free guidance.
544
+ do_classifier_free_guidance = guidance_scale > 1.0
545
+
546
+ # 3. Encode input prompt
547
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
548
+ prompt,
549
+ negative_prompt,
550
+ do_classifier_free_guidance,
551
+ num_videos_per_prompt=num_videos_per_prompt,
552
+ prompt_embeds=prompt_embeds,
553
+ negative_prompt_embeds=negative_prompt_embeds,
554
+ max_sequence_length=max_sequence_length,
555
+ device=device,
556
+ )
557
+ if do_classifier_free_guidance:
558
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
559
+ else:
560
+ in_prompt_embeds = prompt_embeds
561
+
562
+ # 4. Prepare timesteps
563
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
564
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
565
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
566
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
567
+ timesteps = self.scheduler.timesteps
568
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
569
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
570
+ timesteps, _ = retrieve_timesteps(
571
+ self.scheduler,
572
+ device=device,
573
+ sigmas=sampling_sigmas)
574
+ else:
575
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
576
+ self._num_timesteps = len(timesteps)
577
+ if comfyui_progressbar:
578
+ from comfy.utils import ProgressBar
579
+ pbar = ProgressBar(num_inference_steps + 2)
580
+
581
+ # 5. Prepare latents.
582
+ latent_channels = self.vae.config.latent_channels
583
+ latents = self.prepare_latents(
584
+ batch_size * num_videos_per_prompt,
585
+ latent_channels,
586
+ num_frames,
587
+ height,
588
+ width,
589
+ weight_dtype,
590
+ device,
591
+ generator,
592
+ latents,
593
+ )
594
+ if comfyui_progressbar:
595
+ pbar.update(1)
596
+
597
+ # Prepare mask latent variables
598
+ if control_camera_video is not None:
599
+ control_latents = None
600
+ # Rearrange dimensions
601
+ # Concatenate and transpose dimensions
602
+ control_camera_latents = torch.concat(
603
+ [
604
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
605
+ control_camera_video[:, :, 1:]
606
+ ], dim=2
607
+ ).transpose(1, 2)
608
+
609
+ # Reshape, transpose, and view into desired shape
610
+ b, f, c, h, w = control_camera_latents.shape
611
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
612
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
613
+ elif control_video is not None:
614
+ video_length = control_video.shape[2]
615
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
616
+ control_video = control_video.to(dtype=torch.float32)
617
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
618
+ control_video_latents = self.prepare_control_latents(
619
+ None,
620
+ control_video,
621
+ batch_size,
622
+ height,
623
+ width,
624
+ weight_dtype,
625
+ device,
626
+ generator,
627
+ do_classifier_free_guidance
628
+ )[1]
629
+ control_camera_latents = None
630
+ else:
631
+ control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
632
+ control_camera_latents = None
633
+
634
+ if start_image is not None:
635
+ video_length = start_image.shape[2]
636
+ start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width)
637
+ start_image = start_image.to(dtype=torch.float32)
638
+ start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length)
639
+
640
+ start_image_latentes = self.prepare_control_latents(
641
+ None,
642
+ start_image,
643
+ batch_size,
644
+ height,
645
+ width,
646
+ weight_dtype,
647
+ device,
648
+ generator,
649
+ do_classifier_free_guidance
650
+ )[1]
651
+
652
+ start_image_latentes_conv_in = torch.zeros_like(latents)
653
+ if latents.size()[2] != 1:
654
+ start_image_latentes_conv_in[:, :, :1] = start_image_latentes
655
+ else:
656
+ start_image_latentes_conv_in = torch.zeros_like(latents)
657
+
658
+ # Prepare clip latent variables
659
+ if clip_image is not None:
660
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
661
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
662
+ else:
663
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
664
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
665
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
666
+ clip_context = torch.zeros_like(clip_context)
667
+
668
+ if self.transformer.config.get("add_ref_conv", False):
669
+ if ref_image is not None:
670
+ video_length = ref_image.shape[2]
671
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
672
+ ref_image = ref_image.to(dtype=torch.float32)
673
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
674
+
675
+ ref_image_latentes = self.prepare_control_latents(
676
+ None,
677
+ ref_image,
678
+ batch_size,
679
+ height,
680
+ width,
681
+ weight_dtype,
682
+ device,
683
+ generator,
684
+ do_classifier_free_guidance
685
+ )[1]
686
+ ref_image_latentes = ref_image_latentes[:, :, 0]
687
+ else:
688
+ ref_image_latentes = torch.zeros_like(latents)[:, :, 0]
689
+ else:
690
+ if ref_image is not None:
691
+ raise ValueError("The add_ref_conv is False, but ref_image is not None")
692
+ else:
693
+ ref_image_latentes = None
694
+
695
+ if comfyui_progressbar:
696
+ pbar.update(1)
697
+
698
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
699
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
700
+
701
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
702
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
703
+ # 7. Denoising loop
704
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
705
+ self.transformer.num_inference_steps = num_inference_steps
706
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
707
+ for i, t in enumerate(timesteps):
708
+ self.transformer.current_steps = i
709
+
710
+ if self.interrupt:
711
+ continue
712
+
713
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
714
+ if hasattr(self.scheduler, "scale_model_input"):
715
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
716
+
717
+ # Prepare mask latent variables
718
+ if control_camera_video is not None:
719
+ control_latents_input = None
720
+ control_camera_latents_input = (
721
+ torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents
722
+ ).to(device, weight_dtype)
723
+ else:
724
+ control_latents_input = (
725
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
726
+ ).to(device, weight_dtype)
727
+ control_camera_latents_input = None
728
+
729
+ start_image_latentes_conv_in_input = (
730
+ torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in
731
+ ).to(device, weight_dtype)
732
+ control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \
733
+ torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1)
734
+
735
+ clip_context_input = (
736
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
737
+ )
738
+
739
+ if ref_image_latentes is not None:
740
+ full_ref = (
741
+ torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
742
+ ).to(device, weight_dtype)
743
+ else:
744
+ full_ref = None
745
+
746
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
747
+ timestep = t.expand(latent_model_input.shape[0])
748
+
749
+ # predict noise model_output
750
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
751
+ noise_pred = self.transformer(
752
+ x=latent_model_input,
753
+ context=in_prompt_embeds,
754
+ t=timestep,
755
+ seq_len=seq_len,
756
+ y=control_latents_input,
757
+ y_camera=control_camera_latents_input,
758
+ full_ref=full_ref,
759
+ clip_fea=clip_context_input,
760
+ )
761
+
762
+ # perform guidance
763
+ if do_classifier_free_guidance:
764
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
765
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
766
+
767
+ # compute the previous noisy sample x_t -> x_t-1
768
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
769
+
770
+ if callback_on_step_end is not None:
771
+ callback_kwargs = {}
772
+ for k in callback_on_step_end_tensor_inputs:
773
+ callback_kwargs[k] = locals()[k]
774
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
775
+
776
+ latents = callback_outputs.pop("latents", latents)
777
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
778
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
779
+
780
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
781
+ progress_bar.update()
782
+ if comfyui_progressbar:
783
+ pbar.update(1)
784
+
785
+ if output_type == "numpy":
786
+ video = self.decode_latents(latents)
787
+ elif output_type != "latent":
788
+ video = self.decode_latents(latents)
789
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
790
+ else:
791
+ video = latents
792
+
793
+ # Offload all models
794
+ self.maybe_free_model_hooks()
795
+
796
+ if not return_dict:
797
+ video = torch.from_numpy(video)
798
+
799
+ return WanPipelineOutput(videos=video)
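A hedged usage sketch for `WanFunControlPipeline`; the `pipeline` object, checkpoint handling, and the content of the control frames are assumptions for illustration only:

```python
import torch
from PIL import Image

def run_fun_control_example(pipeline, device="cuda"):
    num_frames, height, width = 49, 480, 720
    # Control signal (e.g. pose/depth/canny frames) in (b, c, f, h, w) layout.
    control_video = torch.zeros(1, 3, num_frames, height, width)
    # Optional single-frame start image, also (b, c, 1, h, w).
    start_image = torch.zeros(1, 3, 1, height, width)
    clip_image = Image.new("RGB", (512, 512))  # reference image for the CLIP branch

    output = pipeline(
        prompt="a dancer in a neon-lit alley",
        negative_prompt="distorted, low quality",
        control_video=control_video,
        start_image=start_image,
        clip_image=clip_image,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=torch.Generator(device=device).manual_seed(0),
        output_type="numpy",
    )
    return output.videos
```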
videox_fun/pipeline/pipeline_wan_fun_inpaint.py ADDED
@@ -0,0 +1,734 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.video_processor import VideoProcessor
18
+ from einops import rearrange
19
+ from PIL import Image
20
+ from transformers import T5Tokenizer
21
+
22
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
23
+ WanT5EncoderModel, WanTransformer3DModel)
24
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas)
26
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```python
34
+ pass
35
+ ```
36
+ """
37
+
38
+
39
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
40
+ def retrieve_timesteps(
41
+ scheduler,
42
+ num_inference_steps: Optional[int] = None,
43
+ device: Optional[Union[str, torch.device]] = None,
44
+ timesteps: Optional[List[int]] = None,
45
+ sigmas: Optional[List[float]] = None,
46
+ **kwargs,
47
+ ):
48
+ """
49
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
50
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
51
+
52
+ Args:
53
+ scheduler (`SchedulerMixin`):
54
+ The scheduler to get timesteps from.
55
+ num_inference_steps (`int`):
56
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
57
+ must be `None`.
58
+ device (`str` or `torch.device`, *optional*):
59
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
60
+ timesteps (`List[int]`, *optional*):
61
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
62
+ `num_inference_steps` and `sigmas` must be `None`.
63
+ sigmas (`List[float]`, *optional*):
64
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
65
+ `num_inference_steps` and `timesteps` must be `None`.
66
+
67
+ Returns:
68
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
69
+ second element is the number of inference steps.
70
+ """
71
+ if timesteps is not None and sigmas is not None:
72
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
73
+ if timesteps is not None:
74
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accepts_timesteps:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" timestep schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ elif sigmas is not None:
84
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
85
+ if not accept_sigmas:
86
+ raise ValueError(
87
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
88
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
89
+ )
90
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
91
+ timesteps = scheduler.timesteps
92
+ num_inference_steps = len(timesteps)
93
+ else:
94
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
95
+ timesteps = scheduler.timesteps
96
+ return timesteps, num_inference_steps
97
+
98
+
99
+ def resize_mask(mask, latent, process_first_frame_only=True):
100
+ latent_size = latent.size()
101
+ batch_size, channels, num_frames, height, width = mask.shape
102
+
103
+ if process_first_frame_only:
104
+ target_size = list(latent_size[2:])
105
+ target_size[0] = 1
106
+ first_frame_resized = F.interpolate(
107
+ mask[:, :, 0:1, :, :],
108
+ size=target_size,
109
+ mode='trilinear',
110
+ align_corners=False
111
+ )
112
+
113
+ target_size = list(latent_size[2:])
114
+ target_size[0] = target_size[0] - 1
115
+ if target_size[0] != 0:
116
+ remaining_frames_resized = F.interpolate(
117
+ mask[:, :, 1:, :, :],
118
+ size=target_size,
119
+ mode='trilinear',
120
+ align_corners=False
121
+ )
122
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
123
+ else:
124
+ resized_mask = first_frame_resized
125
+ else:
126
+ target_size = list(latent_size[2:])
127
+ resized_mask = F.interpolate(
128
+ mask,
129
+ size=target_size,
130
+ mode='trilinear',
131
+ align_corners=False
132
+ )
133
+ return resized_mask
134
+
135
+
136
+ @dataclass
137
+ class WanPipelineOutput(BaseOutput):
138
+ r"""
139
+ Output class for CogVideo pipelines.
140
+
141
+ Args:
142
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
143
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
144
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
145
+ `(batch_size, num_frames, channels, height, width)`.
146
+ """
147
+
148
+ videos: torch.Tensor
149
+
150
+
151
+ class WanFunInpaintPipeline(DiffusionPipeline):
152
+ r"""
153
+ Pipeline for text-to-video generation using Wan.
154
+
155
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
156
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
157
+ """
158
+
159
+ _optional_components = []
160
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
161
+
162
+ _callback_tensor_inputs = [
163
+ "latents",
164
+ "prompt_embeds",
165
+ "negative_prompt_embeds",
166
+ ]
167
+
168
+ def __init__(
169
+ self,
170
+ tokenizer: AutoTokenizer,
171
+ text_encoder: WanT5EncoderModel,
172
+ vae: AutoencoderKLWan,
173
+ transformer: WanTransformer3DModel,
174
+ clip_image_encoder: CLIPModel,
175
+ scheduler: FlowMatchEulerDiscreteScheduler,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.register_modules(
180
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
181
+ )
182
+
183
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
184
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.mask_processor = VaeImageProcessor(
186
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
187
+ )
188
+
189
+ def _get_t5_prompt_embeds(
190
+ self,
191
+ prompt: Union[str, List[str]] = None,
192
+ num_videos_per_prompt: int = 1,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ device = device or self._execution_device
198
+ dtype = dtype or self.text_encoder.dtype
199
+
200
+ prompt = [prompt] if isinstance(prompt, str) else prompt
201
+ batch_size = len(prompt)
202
+
203
+ text_inputs = self.tokenizer(
204
+ prompt,
205
+ padding="max_length",
206
+ max_length=max_sequence_length,
207
+ truncation=True,
208
+ add_special_tokens=True,
209
+ return_tensors="pt",
210
+ )
211
+ text_input_ids = text_inputs.input_ids
212
+ prompt_attention_mask = text_inputs.attention_mask
213
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
214
+
215
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
216
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
217
+ logger.warning(
218
+ "The following part of your input was truncated because `max_sequence_length` is set to "
219
+ f" {max_sequence_length} tokens: {removed_text}"
220
+ )
221
+
222
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
223
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
224
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
225
+
226
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
227
+ _, seq_len, _ = prompt_embeds.shape
228
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
229
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
232
+
233
+ def encode_prompt(
234
+ self,
235
+ prompt: Union[str, List[str]],
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ do_classifier_free_guidance: bool = True,
238
+ num_videos_per_prompt: int = 1,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ max_sequence_length: int = 512,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ):
245
+ r"""
246
+ Encodes the prompt into text encoder hidden states.
247
+
248
+ Args:
249
+ prompt (`str` or `List[str]`, *optional*):
250
+ prompt to be encoded
251
+ negative_prompt (`str` or `List[str]`, *optional*):
252
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
253
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
254
+ less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ Whether to use classifier free guidance or not.
257
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
258
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
259
+ prompt_embeds (`torch.Tensor`, *optional*):
260
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
261
+ provided, text embeddings will be generated from `prompt` input argument.
262
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
263
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
264
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
265
+ argument.
266
+ device: (`torch.device`, *optional*):
267
+ torch device
268
+ dtype: (`torch.dtype`, *optional*):
269
+ torch dtype
270
+ """
271
+ device = device or self._execution_device
272
+
273
+ prompt = [prompt] if isinstance(prompt, str) else prompt
274
+ if prompt is not None:
275
+ batch_size = len(prompt)
276
+ else:
277
+ batch_size = prompt_embeds.shape[0]
278
+
279
+ if prompt_embeds is None:
280
+ prompt_embeds = self._get_t5_prompt_embeds(
281
+ prompt=prompt,
282
+ num_videos_per_prompt=num_videos_per_prompt,
283
+ max_sequence_length=max_sequence_length,
284
+ device=device,
285
+ dtype=dtype,
286
+ )
287
+
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
289
+ negative_prompt = negative_prompt or ""
290
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
291
+
292
+ if prompt is not None and type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif batch_size != len(negative_prompt):
298
+ raise ValueError(
299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
301
+ " the batch size of `prompt`."
302
+ )
303
+
304
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
305
+ prompt=negative_prompt,
306
+ num_videos_per_prompt=num_videos_per_prompt,
307
+ max_sequence_length=max_sequence_length,
308
+ device=device,
309
+ dtype=dtype,
310
+ )
311
+
312
+ return prompt_embeds, negative_prompt_embeds
313
+
314
+ def prepare_latents(
315
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
316
+ ):
317
+ if isinstance(generator, list) and len(generator) != batch_size:
318
+ raise ValueError(
319
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
320
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
321
+ )
322
+
323
+ shape = (
324
+ batch_size,
325
+ num_channels_latents,
326
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
327
+ height // self.vae.spatial_compression_ratio,
328
+ width // self.vae.spatial_compression_ratio,
329
+ )
330
+
331
+ if latents is None:
332
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
333
+ else:
334
+ latents = latents.to(device)
335
+
336
+ # scale the initial noise by the standard deviation required by the scheduler
337
+ if hasattr(self.scheduler, "init_noise_sigma"):
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
340
+
341
+ def prepare_mask_latents(
342
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
343
+ ):
344
+ # resize the mask to latents shape as we concatenate the mask to the latents
345
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
346
+ # and half precision
347
+
348
+ if mask is not None:
349
+ mask = mask.to(device=device, dtype=self.vae.dtype)
350
+ bs = 1
351
+ new_mask = []
352
+ for i in range(0, mask.shape[0], bs):
353
+ mask_bs = mask[i : i + bs]
354
+ mask_bs = self.vae.encode(mask_bs)[0]
355
+ mask_bs = mask_bs.mode()
356
+ new_mask.append(mask_bs)
357
+ mask = torch.cat(new_mask, dim = 0)
358
+ # mask = mask * self.vae.config.scaling_factor
359
+
360
+ if masked_image is not None:
361
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
362
+ bs = 1
363
+ new_mask_pixel_values = []
364
+ for i in range(0, masked_image.shape[0], bs):
365
+ mask_pixel_values_bs = masked_image[i : i + bs]
366
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
367
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
368
+ new_mask_pixel_values.append(mask_pixel_values_bs)
369
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
370
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
371
+ else:
372
+ masked_image_latents = None
373
+
374
+ return mask, masked_image_latents
375
+
376
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
377
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
378
+ frames = (frames / 2 + 0.5).clamp(0, 1)
379
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
380
+ frames = frames.cpu().float().numpy()
381
+ return frames
382
+
383
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
384
+ def prepare_extra_step_kwargs(self, generator, eta):
385
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
386
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
387
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
388
+ # and should be between [0, 1]
389
+
390
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
391
+ extra_step_kwargs = {}
392
+ if accepts_eta:
393
+ extra_step_kwargs["eta"] = eta
394
+
395
+ # check if the scheduler accepts generator
396
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
397
+ if accepts_generator:
398
+ extra_step_kwargs["generator"] = generator
399
+ return extra_step_kwargs
400
+
401
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
402
+ def check_inputs(
403
+ self,
404
+ prompt,
405
+ height,
406
+ width,
407
+ negative_prompt,
408
+ callback_on_step_end_tensor_inputs,
409
+ prompt_embeds=None,
410
+ negative_prompt_embeds=None,
411
+ ):
412
+ if height % 8 != 0 or width % 8 != 0:
413
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
414
+
415
+ if callback_on_step_end_tensor_inputs is not None and not all(
416
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
417
+ ):
418
+ raise ValueError(
419
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
420
+ )
421
+ if prompt is not None and prompt_embeds is not None:
422
+ raise ValueError(
423
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
424
+ " only forward one of the two."
425
+ )
426
+ elif prompt is None and prompt_embeds is None:
427
+ raise ValueError(
428
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
429
+ )
430
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
431
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
432
+
433
+ if prompt is not None and negative_prompt_embeds is not None:
434
+ raise ValueError(
435
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
436
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
437
+ )
438
+
439
+ if negative_prompt is not None and negative_prompt_embeds is not None:
440
+ raise ValueError(
441
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
442
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
443
+ )
444
+
445
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
446
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
447
+ raise ValueError(
448
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
449
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
450
+ f" {negative_prompt_embeds.shape}."
451
+ )
452
+
453
+ @property
454
+ def guidance_scale(self):
455
+ return self._guidance_scale
456
+
457
+ @property
458
+ def num_timesteps(self):
459
+ return self._num_timesteps
460
+
461
+ @property
462
+ def attention_kwargs(self):
463
+ return self._attention_kwargs
464
+
465
+ @property
466
+ def interrupt(self):
467
+ return self._interrupt
468
+
469
+ @torch.no_grad()
470
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
471
+ def __call__(
472
+ self,
473
+ prompt: Optional[Union[str, List[str]]] = None,
474
+ negative_prompt: Optional[Union[str, List[str]]] = None,
475
+ height: int = 480,
476
+ width: int = 720,
477
+ video: Union[torch.FloatTensor] = None,
478
+ mask_video: Union[torch.FloatTensor] = None,
479
+ num_frames: int = 49,
480
+ num_inference_steps: int = 50,
481
+ timesteps: Optional[List[int]] = None,
482
+ guidance_scale: float = 6,
483
+ num_videos_per_prompt: int = 1,
484
+ eta: float = 0.0,
485
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
486
+ latents: Optional[torch.FloatTensor] = None,
487
+ prompt_embeds: Optional[torch.FloatTensor] = None,
488
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
489
+ output_type: str = "numpy",
490
+ return_dict: bool = False,
491
+ callback_on_step_end: Optional[
492
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
493
+ ] = None,
494
+ attention_kwargs: Optional[Dict[str, Any]] = None,
495
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
496
+ clip_image: Image = None,
497
+ max_sequence_length: int = 512,
498
+ comfyui_progressbar: bool = False,
499
+ shift: int = 5,
500
+ ) -> Union[WanPipelineOutput, Tuple]:
501
+ """
502
+ Function invoked when calling the pipeline for generation.
503
+ Args:
504
+
505
+ Examples:
506
+
507
+ Returns:
508
+
509
+ """
510
+
511
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
512
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
513
+ num_videos_per_prompt = 1
514
+
515
+ # 1. Check inputs. Raise error if not correct
516
+ self.check_inputs(
517
+ prompt,
518
+ height,
519
+ width,
520
+ negative_prompt,
521
+ callback_on_step_end_tensor_inputs,
522
+ prompt_embeds,
523
+ negative_prompt_embeds,
524
+ )
525
+ self._guidance_scale = guidance_scale
526
+ self._attention_kwargs = attention_kwargs
527
+ self._interrupt = False
528
+
529
+ # 2. Default call parameters
530
+ if prompt is not None and isinstance(prompt, str):
531
+ batch_size = 1
532
+ elif prompt is not None and isinstance(prompt, list):
533
+ batch_size = len(prompt)
534
+ else:
535
+ batch_size = prompt_embeds.shape[0]
536
+
537
+ device = self._execution_device
538
+ weight_dtype = self.text_encoder.dtype
539
+
540
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
541
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
542
+ # corresponds to doing no classifier free guidance.
543
+ do_classifier_free_guidance = guidance_scale > 1.0
544
+
545
+ # 3. Encode input prompt
546
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
547
+ prompt,
548
+ negative_prompt,
549
+ do_classifier_free_guidance,
550
+ num_videos_per_prompt=num_videos_per_prompt,
551
+ prompt_embeds=prompt_embeds,
552
+ negative_prompt_embeds=negative_prompt_embeds,
553
+ max_sequence_length=max_sequence_length,
554
+ device=device,
555
+ )
556
+ if do_classifier_free_guidance:
557
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
558
+ else:
559
+ in_prompt_embeds = prompt_embeds
560
+
561
+ # 4. Prepare timesteps
562
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
563
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
564
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
565
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
566
+ timesteps = self.scheduler.timesteps
567
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
568
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
569
+ timesteps, _ = retrieve_timesteps(
570
+ self.scheduler,
571
+ device=device,
572
+ sigmas=sampling_sigmas)
573
+ else:
574
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
575
+ self._num_timesteps = len(timesteps)
576
+ if comfyui_progressbar:
577
+ from comfy.utils import ProgressBar
578
+ pbar = ProgressBar(num_inference_steps + 2)
579
+
580
+ # 5. Prepare latents.
581
+ if video is not None:
582
+ video_length = video.shape[2]
583
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
584
+ init_video = init_video.to(dtype=torch.float32)
585
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
586
+ else:
587
+ init_video = None
588
+
589
+ latent_channels = self.vae.config.latent_channels
590
+ latents = self.prepare_latents(
591
+ batch_size * num_videos_per_prompt,
592
+ latent_channels,
593
+ num_frames,
594
+ height,
595
+ width,
596
+ weight_dtype,
597
+ device,
598
+ generator,
599
+ latents,
600
+ )
601
+ if comfyui_progressbar:
602
+ pbar.update(1)
603
+
604
+ # Prepare mask latent variables
605
+ if init_video is not None:
606
+ if (mask_video == 255).all():
607
+ mask_latents = torch.tile(
608
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
609
+ )
610
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
611
+ else:
612
+ bs, _, video_length, height, width = video.size()
613
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
614
+ mask_condition = mask_condition.to(dtype=torch.float32)
615
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
616
+
617
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
618
+ _, masked_video_latents = self.prepare_mask_latents(
619
+ None,
620
+ masked_video,
621
+ batch_size,
622
+ height,
623
+ width,
624
+ weight_dtype,
625
+ device,
626
+ generator,
627
+ do_classifier_free_guidance,
628
+ noise_aug_strength=None,
629
+ )
630
+
631
+ mask_condition = torch.concat(
632
+ [
633
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
634
+ mask_condition[:, :, 1:]
635
+ ], dim=2
636
+ )
637
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
638
+ mask_condition = mask_condition.transpose(1, 2)
639
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
640
+
641
+ # Prepare clip latent variables
642
+ if clip_image is not None:
643
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
644
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
645
+ else:
646
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
647
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
648
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
649
+ clip_context = torch.zeros_like(clip_context)
650
+ if comfyui_progressbar:
651
+ pbar.update(1)
652
+
653
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
654
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
655
+
656
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
657
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
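+ # Illustrative arithmetic (assuming the usual Wan config: 8x spatial / 4x temporal VAE compression, 1x2x2 patches):
+ # 480x720 pixels -> a 60x90 latent grid, 49 frames -> 13 latent frames, so seq_len = ceil(60 * 90 / (2 * 2) * 13) = 17550.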
658
+ # 7. Denoising loop
659
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
660
+ self.transformer.num_inference_steps = num_inference_steps
661
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
662
+ for i, t in enumerate(timesteps):
663
+ self.transformer.current_steps = i
664
+
665
+ if self.interrupt:
666
+ continue
667
+
668
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
669
+ if hasattr(self.scheduler, "scale_model_input"):
670
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
671
+
672
+ if init_video is not None:
673
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
674
+ masked_video_latents_input = (
675
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
676
+ )
677
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
678
+
679
+ clip_context_input = (
680
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
681
+ )
682
+
683
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
684
+ timestep = t.expand(latent_model_input.shape[0])
685
+
686
+ # predict noise model_output
687
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
688
+ noise_pred = self.transformer(
689
+ x=latent_model_input,
690
+ context=in_prompt_embeds,
691
+ t=timestep,
692
+ seq_len=seq_len,
693
+ y=y,
694
+ clip_fea=clip_context_input,
695
+ )
696
+
697
+ # perform guidance
698
+ if do_classifier_free_guidance:
699
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
700
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
701
+
702
+ # compute the previous noisy sample x_t -> x_t-1
703
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
704
+
705
+ if callback_on_step_end is not None:
706
+ callback_kwargs = {}
707
+ for k in callback_on_step_end_tensor_inputs:
708
+ callback_kwargs[k] = locals()[k]
709
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
710
+
711
+ latents = callback_outputs.pop("latents", latents)
712
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
713
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
714
+
715
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
716
+ progress_bar.update()
717
+ if comfyui_progressbar:
718
+ pbar.update(1)
719
+
720
+ if output_type == "numpy":
721
+ video = self.decode_latents(latents)
722
+ elif not output_type == "latent":
723
+ video = self.decode_latents(latents)
724
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
725
+ else:
726
+ video = latents
727
+
728
+ # Offload all models
729
+ self.maybe_free_model_hooks()
730
+
731
+ if not return_dict:
732
+ video = torch.from_numpy(video)
733
+
734
+ return WanPipelineOutput(videos=video)
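The `__call__` docstring above leaves its Args/Examples sections empty, so here is a minimal, hypothetical call sketch for the inpaint-style Wan pipeline that ends here. The assembled `pipe` object, tensor contents, and CUDA device are assumptions for illustration; only the keyword names come from the signature in this diff.

```python
import torch

# Assumed: `pipe` is an already-assembled instance of the inpaint pipeline above.
video = torch.zeros((1, 3, 49, 480, 720))              # (batch, channels, frames, height, width)
mask_video = torch.ones((1, 1, 49, 480, 720)) * 255.0  # 255 everywhere => regenerate every pixel

result = pipe(
    prompt="a corgi running along the beach at sunset",
    negative_prompt="blurry, low quality",
    video=video,
    mask_video=mask_video,
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(42),
    output_type="numpy",
)
frames = result.videos  # decoded frames; exact type depends on `output_type` / `return_dict`
```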
videox_fun/pipeline/pipeline_wan_phantom.py ADDED
@@ -0,0 +1,695 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
24
+ WanT5EncoderModel, WanTransformer3DModel)
25
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
26
+ get_sampling_sigmas)
27
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```python
35
+ pass
36
+ ```
37
+ """
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
74
+ if timesteps is not None:
75
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
76
+ if not accepts_timesteps:
77
+ raise ValueError(
78
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
79
+ f" timestep schedules. Please check whether you are using the correct scheduler."
80
+ )
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
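`retrieve_timesteps` is copied from diffusers and simply routes `num_inference_steps`, `timesteps`, or `sigmas` into `scheduler.set_timesteps`, then reads the schedule back. A minimal sketch of the default branch, assuming a stock diffusers scheduler (the printed values are illustrative):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(n, timesteps.shape)  # e.g. 30, torch.Size([30])
```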
98
+
99
+
100
+ def resize_mask(mask, latent, process_first_frame_only=True):
101
+ latent_size = latent.size()
102
+ batch_size, channels, num_frames, height, width = mask.shape
103
+
104
+ if process_first_frame_only:
105
+ target_size = list(latent_size[2:])
106
+ target_size[0] = 1
107
+ first_frame_resized = F.interpolate(
108
+ mask[:, :, 0:1, :, :],
109
+ size=target_size,
110
+ mode='trilinear',
111
+ align_corners=False
112
+ )
113
+
114
+ target_size = list(latent_size[2:])
115
+ target_size[0] = target_size[0] - 1
116
+ if target_size[0] != 0:
117
+ remaining_frames_resized = F.interpolate(
118
+ mask[:, :, 1:, :, :],
119
+ size=target_size,
120
+ mode='trilinear',
121
+ align_corners=False
122
+ )
123
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
124
+ else:
125
+ resized_mask = first_frame_resized
126
+ else:
127
+ target_size = list(latent_size[2:])
128
+ resized_mask = F.interpolate(
129
+ mask,
130
+ size=target_size,
131
+ mode='trilinear',
132
+ align_corners=False
133
+ )
134
+ return resized_mask
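`resize_mask` downsamples a pixel-space mask to the latent grid; with `process_first_frame_only` set, frame 0 is interpolated on its own (mirroring how the first video frame maps to the first latent frame) and the remaining frames fill the rest of the temporal axis. A shape-only sketch with made-up sizes:

```python
import torch

mask = torch.rand(1, 1, 17, 64, 64)   # (batch, 1, frames, H, W) in pixel space
latent = torch.rand(1, 16, 5, 8, 8)   # (batch, channels, latent_frames, h, w)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 5, 8, 8])
```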
135
+
136
+
137
+ @dataclass
138
+ class WanPipelineOutput(BaseOutput):
139
+ r"""
140
+ Output class for Wan pipelines.
141
+
142
+ Args:
143
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
144
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
145
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
146
+ `(batch_size, num_frames, channels, height, width)`.
147
+ """
148
+
149
+ videos: torch.Tensor
150
+
151
+
152
+ class WanFunPhantomPipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline for subject-reference-conditioned video generation using Wan (Phantom).
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
157
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
158
+ """
159
+
160
+ _optional_components = []
161
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
162
+
163
+ _callback_tensor_inputs = [
164
+ "latents",
165
+ "prompt_embeds",
166
+ "negative_prompt_embeds",
167
+ ]
168
+
169
+ def __init__(
170
+ self,
171
+ tokenizer: AutoTokenizer,
172
+ text_encoder: WanT5EncoderModel,
173
+ vae: AutoencoderKLWan,
174
+ transformer: WanTransformer3DModel,
175
+ scheduler: FlowMatchEulerDiscreteScheduler,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.register_modules(
180
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
181
+ )
182
+
183
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
184
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.mask_processor = VaeImageProcessor(
186
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
187
+ )
188
+
189
+ def _get_t5_prompt_embeds(
190
+ self,
191
+ prompt: Union[str, List[str]] = None,
192
+ num_videos_per_prompt: int = 1,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ device = device or self._execution_device
198
+ dtype = dtype or self.text_encoder.dtype
199
+
200
+ prompt = [prompt] if isinstance(prompt, str) else prompt
201
+ batch_size = len(prompt)
202
+
203
+ text_inputs = self.tokenizer(
204
+ prompt,
205
+ padding="max_length",
206
+ max_length=max_sequence_length,
207
+ truncation=True,
208
+ add_special_tokens=True,
209
+ return_tensors="pt",
210
+ )
211
+ text_input_ids = text_inputs.input_ids
212
+ prompt_attention_mask = text_inputs.attention_mask
213
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
214
+
215
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
216
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
217
+ logger.warning(
218
+ "The following part of your input was truncated because `max_sequence_length` is set to "
219
+ f" {max_sequence_length} tokens: {removed_text}"
220
+ )
221
+
222
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
223
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
224
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
225
+
226
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
227
+ _, seq_len, _ = prompt_embeds.shape
228
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
229
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
232
+
233
+ def encode_prompt(
234
+ self,
235
+ prompt: Union[str, List[str]],
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ do_classifier_free_guidance: bool = True,
238
+ num_videos_per_prompt: int = 1,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ max_sequence_length: int = 512,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ):
245
+ r"""
246
+ Encodes the prompt into text encoder hidden states.
247
+
248
+ Args:
249
+ prompt (`str` or `List[str]`, *optional*):
250
+ prompt to be encoded
251
+ negative_prompt (`str` or `List[str]`, *optional*):
252
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
253
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
254
+ less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ Whether to use classifier free guidance or not.
257
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
258
+ Number of videos that should be generated per prompt.
259
+ prompt_embeds (`torch.Tensor`, *optional*):
260
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
261
+ provided, text embeddings will be generated from `prompt` input argument.
262
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
263
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
264
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
265
+ argument.
266
+ device: (`torch.device`, *optional*):
267
+ torch device
268
+ dtype: (`torch.dtype`, *optional*):
269
+ torch dtype
270
+ """
271
+ device = device or self._execution_device
272
+
273
+ prompt = [prompt] if isinstance(prompt, str) else prompt
274
+ if prompt is not None:
275
+ batch_size = len(prompt)
276
+ else:
277
+ batch_size = prompt_embeds.shape[0]
278
+
279
+ if prompt_embeds is None:
280
+ prompt_embeds = self._get_t5_prompt_embeds(
281
+ prompt=prompt,
282
+ num_videos_per_prompt=num_videos_per_prompt,
283
+ max_sequence_length=max_sequence_length,
284
+ device=device,
285
+ dtype=dtype,
286
+ )
287
+
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
289
+ negative_prompt = negative_prompt or ""
290
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
291
+
292
+ if prompt is not None and type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif batch_size != len(negative_prompt):
298
+ raise ValueError(
299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
301
+ " the batch size of `prompt`."
302
+ )
303
+
304
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
305
+ prompt=negative_prompt,
306
+ num_videos_per_prompt=num_videos_per_prompt,
307
+ max_sequence_length=max_sequence_length,
308
+ device=device,
309
+ dtype=dtype,
310
+ )
311
+
312
+ return prompt_embeds, negative_prompt_embeds
313
+
314
+ def prepare_latents(
315
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
316
+ ):
317
+ if isinstance(generator, list) and len(generator) != batch_size:
318
+ raise ValueError(
319
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
320
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
321
+ )
322
+
323
+ shape = (
324
+ batch_size,
325
+ num_channels_latents,
326
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
327
+ height // self.vae.spatial_compression_ratio,
328
+ width // self.vae.spatial_compression_ratio,
329
+ )
330
+
331
+ if latents is None:
332
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
333
+ else:
334
+ latents = latents.to(device)
335
+
336
+ # scale the initial noise by the standard deviation required by the scheduler
337
+ if hasattr(self.scheduler, "init_noise_sigma"):
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
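For intuition, the latent shape produced by `prepare_latents` follows directly from the VAE compression ratios. The ratios and channel count below are assumed typical Wan values (they come from the VAE config, not from this diff):

```python
num_frames, height, width = 49, 480, 720
temporal_ratio, spatial_ratio, latent_channels = 4, 8, 16  # assumed VAE config values

shape = (1, latent_channels,
         (num_frames - 1) // temporal_ratio + 1,
         height // spatial_ratio,
         width // spatial_ratio)
print(shape)  # (1, 16, 13, 60, 90)
```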
340
+
341
+ def prepare_control_latents(
342
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
343
+ ):
344
+ # resize the control to latents shape as we concatenate the control to the latents
345
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
346
+ # and half precision
347
+
348
+ if control is not None:
349
+ control = control.to(device=device, dtype=dtype)
350
+ bs = 1
351
+ new_control = []
352
+ for i in range(0, control.shape[0], bs):
353
+ control_bs = control[i : i + bs]
354
+ control_bs = self.vae.encode(control_bs)[0]
355
+ control_bs = control_bs.mode()
356
+ new_control.append(control_bs)
357
+ control = torch.cat(new_control, dim = 0)
358
+
359
+ if control_image is not None:
360
+ control_image = control_image.to(device=device, dtype=dtype)
361
+ bs = 1
362
+ new_control_pixel_values = []
363
+ for i in range(0, control_image.shape[0], bs):
364
+ control_pixel_values_bs = control_image[i : i + bs]
365
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
366
+ control_pixel_values_bs = control_pixel_values_bs.mode()
367
+ new_control_pixel_values.append(control_pixel_values_bs)
368
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
369
+ else:
370
+ control_image_latents = None
371
+
372
+ return control, control_image_latents
373
+
374
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
375
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
376
+ frames = (frames / 2 + 0.5).clamp(0, 1)
377
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
378
+ frames = frames.cpu().float().numpy()
379
+ return frames
380
+
381
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
382
+ def prepare_extra_step_kwargs(self, generator, eta):
383
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
384
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
385
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
386
+ # and should be between [0, 1]
387
+
388
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
389
+ extra_step_kwargs = {}
390
+ if accepts_eta:
391
+ extra_step_kwargs["eta"] = eta
392
+
393
+ # check if the scheduler accepts generator
394
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
395
+ if accepts_generator:
396
+ extra_step_kwargs["generator"] = generator
397
+ return extra_step_kwargs
398
+
399
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
400
+ def check_inputs(
401
+ self,
402
+ prompt,
403
+ height,
404
+ width,
405
+ negative_prompt,
406
+ callback_on_step_end_tensor_inputs,
407
+ prompt_embeds=None,
408
+ negative_prompt_embeds=None,
409
+ ):
410
+ if height % 8 != 0 or width % 8 != 0:
411
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
412
+
413
+ if callback_on_step_end_tensor_inputs is not None and not all(
414
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
415
+ ):
416
+ raise ValueError(
417
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
418
+ )
419
+ if prompt is not None and prompt_embeds is not None:
420
+ raise ValueError(
421
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
422
+ " only forward one of the two."
423
+ )
424
+ elif prompt is None and prompt_embeds is None:
425
+ raise ValueError(
426
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
427
+ )
428
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
429
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
430
+
431
+ if prompt is not None and negative_prompt_embeds is not None:
432
+ raise ValueError(
433
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
434
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
435
+ )
436
+
437
+ if negative_prompt is not None and negative_prompt_embeds is not None:
438
+ raise ValueError(
439
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
440
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
441
+ )
442
+
443
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
444
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
445
+ raise ValueError(
446
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
447
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
448
+ f" {negative_prompt_embeds.shape}."
449
+ )
450
+
451
+ @property
452
+ def guidance_scale(self):
453
+ return self._guidance_scale
454
+
455
+ @property
456
+ def num_timesteps(self):
457
+ return self._num_timesteps
458
+
459
+ @property
460
+ def attention_kwargs(self):
461
+ return self._attention_kwargs
462
+
463
+ @property
464
+ def interrupt(self):
465
+ return self._interrupt
466
+
467
+ @torch.no_grad()
468
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
469
+ def __call__(
470
+ self,
471
+ prompt: Optional[Union[str, List[str]]] = None,
472
+ negative_prompt: Optional[Union[str, List[str]]] = None,
473
+ height: int = 480,
474
+ width: int = 720,
475
+ subject_ref_images: Union[torch.FloatTensor] = None,
476
+ num_frames: int = 49,
477
+ num_inference_steps: int = 50,
478
+ timesteps: Optional[List[int]] = None,
479
+ guidance_scale: float = 6,
480
+ num_videos_per_prompt: int = 1,
481
+ eta: float = 0.0,
482
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
483
+ latents: Optional[torch.FloatTensor] = None,
484
+ prompt_embeds: Optional[torch.FloatTensor] = None,
485
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
486
+ output_type: str = "numpy",
487
+ return_dict: bool = False,
488
+ callback_on_step_end: Optional[
489
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
490
+ ] = None,
491
+ attention_kwargs: Optional[Dict[str, Any]] = None,
492
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
493
+ max_sequence_length: int = 512,
494
+ comfyui_progressbar: bool = False,
495
+ shift: int = 5,
496
+ ) -> Union[WanPipelineOutput, Tuple]:
497
+ """
498
+ Function invoked when calling the pipeline for generation.
499
+ Args:
500
+
501
+ Examples:
502
+
503
+ Returns:
504
+
505
+ """
506
+
507
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
508
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
509
+ num_videos_per_prompt = 1
510
+
511
+ # 1. Check inputs. Raise error if not correct
512
+ self.check_inputs(
513
+ prompt,
514
+ height,
515
+ width,
516
+ negative_prompt,
517
+ callback_on_step_end_tensor_inputs,
518
+ prompt_embeds,
519
+ negative_prompt_embeds,
520
+ )
521
+ self._guidance_scale = guidance_scale
522
+ self._attention_kwargs = attention_kwargs
523
+ self._interrupt = False
524
+
525
+ # 2. Default call parameters
526
+ if prompt is not None and isinstance(prompt, str):
527
+ batch_size = 1
528
+ elif prompt is not None and isinstance(prompt, list):
529
+ batch_size = len(prompt)
530
+ else:
531
+ batch_size = prompt_embeds.shape[0]
532
+
533
+ device = self._execution_device
534
+ weight_dtype = self.text_encoder.dtype
535
+
536
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
537
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
538
+ # corresponds to doing no classifier free guidance.
539
+ do_classifier_free_guidance = guidance_scale > 1.0
540
+
541
+ # 3. Encode input prompt
542
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
543
+ prompt,
544
+ negative_prompt,
545
+ do_classifier_free_guidance,
546
+ num_videos_per_prompt=num_videos_per_prompt,
547
+ prompt_embeds=prompt_embeds,
548
+ negative_prompt_embeds=negative_prompt_embeds,
549
+ max_sequence_length=max_sequence_length,
550
+ device=device,
551
+ )
552
+ if do_classifier_free_guidance:
553
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
554
+ else:
555
+ in_prompt_embeds = prompt_embeds
556
+
557
+ # 4. Prepare timesteps
558
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
559
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
560
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
561
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
562
+ timesteps = self.scheduler.timesteps
563
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
564
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
565
+ timesteps, _ = retrieve_timesteps(
566
+ self.scheduler,
567
+ device=device,
568
+ sigmas=sampling_sigmas)
569
+ else:
570
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
571
+ self._num_timesteps = len(timesteps)
572
+ if comfyui_progressbar:
573
+ from comfy.utils import ProgressBar
574
+ pbar = ProgressBar(num_inference_steps + 2)
575
+
576
+ # 5. Prepare latents.
577
+ latent_channels = self.vae.config.latent_channels
578
+ latents = self.prepare_latents(
579
+ batch_size * num_videos_per_prompt,
580
+ latent_channels,
581
+ num_frames,
582
+ height,
583
+ width,
584
+ weight_dtype,
585
+ device,
586
+ generator,
587
+ latents,
588
+ )
589
+ if comfyui_progressbar:
590
+ pbar.update(1)
591
+
592
+ if subject_ref_images is not None:
593
+ video_length = subject_ref_images.shape[2]
594
+ subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
595
+ subject_ref_images = subject_ref_images.to(dtype=torch.float32)
596
+ subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
597
+
598
+ subject_ref_images_latents = torch.cat(
599
+ [
600
+ self.prepare_control_latents(
601
+ None,
602
+ subject_ref_images[:, :, i:i+1],
603
+ batch_size,
604
+ height,
605
+ width,
606
+ weight_dtype,
607
+ device,
608
+ generator,
609
+ do_classifier_free_guidance
610
+ )[1] for i in range(video_length)
611
+ ], dim = 2
612
+ )
613
+
614
+ if comfyui_progressbar:
615
+ pbar.update(1)
616
+
617
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
618
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
619
+
620
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
621
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
622
+ # 7. Denoising loop
623
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
624
+ self.transformer.num_inference_steps = num_inference_steps
625
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
626
+ for i, t in enumerate(timesteps):
627
+ self.transformer.current_steps = i
628
+
629
+ if self.interrupt:
630
+ continue
631
+
632
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
633
+ if hasattr(self.scheduler, "scale_model_input"):
634
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
635
+
636
+ if subject_ref_images is not None:
637
+ subject_ref = (
638
+ torch.cat(
639
+ [torch.zeros_like(subject_ref_images_latents), subject_ref_images_latents]
640
+ ) if do_classifier_free_guidance else subject_ref_images_latents
641
+ ).to(device, weight_dtype)
642
+ else:
643
+ subject_ref = None
644
+
645
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
646
+ timestep = t.expand(latent_model_input.shape[0])
647
+
648
+ # predict noise model_output
649
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
650
+ noise_pred = self.transformer(
651
+ x=latent_model_input,
652
+ context=in_prompt_embeds,
653
+ t=timestep,
654
+ seq_len=seq_len,
655
+ subject_ref=subject_ref,
656
+ )
657
+
658
+ # perform guidance
659
+ if do_classifier_free_guidance:
660
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
661
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
662
+
663
+ # compute the previous noisy sample x_t -> x_t-1
664
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
665
+
666
+ if callback_on_step_end is not None:
667
+ callback_kwargs = {}
668
+ for k in callback_on_step_end_tensor_inputs:
669
+ callback_kwargs[k] = locals()[k]
670
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
671
+
672
+ latents = callback_outputs.pop("latents", latents)
673
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
674
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
675
+
676
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
677
+ progress_bar.update()
678
+ if comfyui_progressbar:
679
+ pbar.update(1)
680
+
681
+ if output_type == "numpy":
682
+ video = self.decode_latents(latents)
683
+ elif not output_type == "latent":
684
+ video = self.decode_latents(latents)
685
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
686
+ else:
687
+ video = latents
688
+
689
+ # Offload all models
690
+ self.maybe_free_model_hooks()
691
+
692
+ if not return_dict:
693
+ video = torch.from_numpy(video)
694
+
695
+ return WanPipelineOutput(videos=video)
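As with the pipeline above, the Examples section of `WanFunPhantomPipeline.__call__` is left empty; a hypothetical call sketch follows. Only the keyword names are taken from the signature in this diff; the assembled `pipe`, tensor contents, and device are assumptions. `subject_ref_images` packs the reference images along the frame axis of a `(batch, channels, frames, height, width)` tensor.

```python
import torch

# Assumed: `pipe` is an already-assembled WanFunPhantomPipeline instance.
subject_refs = torch.zeros((1, 3, 2, 480, 720))  # two subject reference images

result = pipe(
    prompt="the same character dancing in the rain",
    subject_ref_images=subject_refs,
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
frames = result.videos
```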
videox_fun/pipeline/pipeline_wan_vace.py ADDED
@@ -0,0 +1,787 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
24
+ WanT5EncoderModel, VaceWanTransformer3DModel)
25
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
26
+ get_sampling_sigmas)
27
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ EXAMPLE_DOC_STRING = """
33
+ Examples:
34
+ ```python
35
+ pass
36
+ ```
37
+ """
38
+
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
41
+ def retrieve_timesteps(
42
+ scheduler,
43
+ num_inference_steps: Optional[int] = None,
44
+ device: Optional[Union[str, torch.device]] = None,
45
+ timesteps: Optional[List[int]] = None,
46
+ sigmas: Optional[List[float]] = None,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
51
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
52
+
53
+ Args:
54
+ scheduler (`SchedulerMixin`):
55
+ The scheduler to get timesteps from.
56
+ num_inference_steps (`int`):
57
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
58
+ must be `None`.
59
+ device (`str` or `torch.device`, *optional*):
60
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
61
+ timesteps (`List[int]`, *optional*):
62
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
63
+ `num_inference_steps` and `sigmas` must be `None`.
64
+ sigmas (`List[float]`, *optional*):
65
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
66
+ `num_inference_steps` and `timesteps` must be `None`.
67
+
68
+ Returns:
69
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
70
+ second element is the number of inference steps.
71
+ """
72
+ if timesteps is not None and sigmas is not None:
73
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
74
+ if timesteps is not None:
75
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
76
+ if not accepts_timesteps:
77
+ raise ValueError(
78
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
79
+ f" timestep schedules. Please check whether you are using the correct scheduler."
80
+ )
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
98
+
99
+
100
+ def resize_mask(mask, latent, process_first_frame_only=True):
101
+ latent_size = latent.size()
102
+ batch_size, channels, num_frames, height, width = mask.shape
103
+
104
+ if process_first_frame_only:
105
+ target_size = list(latent_size[2:])
106
+ target_size[0] = 1
107
+ first_frame_resized = F.interpolate(
108
+ mask[:, :, 0:1, :, :],
109
+ size=target_size,
110
+ mode='trilinear',
111
+ align_corners=False
112
+ )
113
+
114
+ target_size = list(latent_size[2:])
115
+ target_size[0] = target_size[0] - 1
116
+ if target_size[0] != 0:
117
+ remaining_frames_resized = F.interpolate(
118
+ mask[:, :, 1:, :, :],
119
+ size=target_size,
120
+ mode='trilinear',
121
+ align_corners=False
122
+ )
123
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
124
+ else:
125
+ resized_mask = first_frame_resized
126
+ else:
127
+ target_size = list(latent_size[2:])
128
+ resized_mask = F.interpolate(
129
+ mask,
130
+ size=target_size,
131
+ mode='trilinear',
132
+ align_corners=False
133
+ )
134
+ return resized_mask
135
+
136
+
137
+ @dataclass
138
+ class WanPipelineOutput(BaseOutput):
139
+ r"""
140
+ Output class for Wan pipelines.
141
+
142
+ Args:
143
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
144
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
145
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
146
+ `(batch_size, num_frames, channels, height, width)`.
147
+ """
148
+
149
+ videos: torch.Tensor
150
+
151
+
152
+ class WanVacePipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline for text-to-video generation using Wan.
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
157
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
158
+ """
159
+
160
+ _optional_components = []
161
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
162
+
163
+ _callback_tensor_inputs = [
164
+ "latents",
165
+ "prompt_embeds",
166
+ "negative_prompt_embeds",
167
+ ]
168
+
169
+ def __init__(
170
+ self,
171
+ tokenizer: AutoTokenizer,
172
+ text_encoder: WanT5EncoderModel,
173
+ vae: AutoencoderKLWan,
174
+ transformer: VaceWanTransformer3DModel,
175
+ scheduler: FlowMatchEulerDiscreteScheduler,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.register_modules(
180
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
181
+ )
182
+
183
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
184
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
185
+ self.mask_processor = VaeImageProcessor(
186
+ vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
187
+ )
188
+
189
+ def _get_t5_prompt_embeds(
190
+ self,
191
+ prompt: Union[str, List[str]] = None,
192
+ num_videos_per_prompt: int = 1,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ device = device or self._execution_device
198
+ dtype = dtype or self.text_encoder.dtype
199
+
200
+ prompt = [prompt] if isinstance(prompt, str) else prompt
201
+ batch_size = len(prompt)
202
+
203
+ text_inputs = self.tokenizer(
204
+ prompt,
205
+ padding="max_length",
206
+ max_length=max_sequence_length,
207
+ truncation=True,
208
+ add_special_tokens=True,
209
+ return_tensors="pt",
210
+ )
211
+ text_input_ids = text_inputs.input_ids
212
+ prompt_attention_mask = text_inputs.attention_mask
213
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
214
+
215
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
216
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
217
+ logger.warning(
218
+ "The following part of your input was truncated because `max_sequence_length` is set to "
219
+ f" {max_sequence_length} tokens: {removed_text}"
220
+ )
221
+
222
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
223
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
224
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
225
+
226
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
227
+ _, seq_len, _ = prompt_embeds.shape
228
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
229
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
232
+
233
+ def encode_prompt(
234
+ self,
235
+ prompt: Union[str, List[str]],
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ do_classifier_free_guidance: bool = True,
238
+ num_videos_per_prompt: int = 1,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ max_sequence_length: int = 512,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ):
245
+ r"""
246
+ Encodes the prompt into text encoder hidden states.
247
+
248
+ Args:
249
+ prompt (`str` or `List[str]`, *optional*):
250
+ prompt to be encoded
251
+ negative_prompt (`str` or `List[str]`, *optional*):
252
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
253
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
254
+ less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ Whether to use classifier free guidance or not.
257
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
258
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
259
+ prompt_embeds (`torch.Tensor`, *optional*):
260
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
261
+ provided, text embeddings will be generated from `prompt` input argument.
262
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
263
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
264
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
265
+ argument.
266
+ device: (`torch.device`, *optional*):
267
+ torch device
268
+ dtype: (`torch.dtype`, *optional*):
269
+ torch dtype
270
+ """
271
+ device = device or self._execution_device
272
+
273
+ prompt = [prompt] if isinstance(prompt, str) else prompt
274
+ if prompt is not None:
275
+ batch_size = len(prompt)
276
+ else:
277
+ batch_size = prompt_embeds.shape[0]
278
+
279
+ if prompt_embeds is None:
280
+ prompt_embeds = self._get_t5_prompt_embeds(
281
+ prompt=prompt,
282
+ num_videos_per_prompt=num_videos_per_prompt,
283
+ max_sequence_length=max_sequence_length,
284
+ device=device,
285
+ dtype=dtype,
286
+ )
287
+
288
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
289
+ negative_prompt = negative_prompt or ""
290
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
291
+
292
+ if prompt is not None and type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif batch_size != len(negative_prompt):
298
+ raise ValueError(
299
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
300
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
301
+ " the batch size of `prompt`."
302
+ )
303
+
304
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
305
+ prompt=negative_prompt,
306
+ num_videos_per_prompt=num_videos_per_prompt,
307
+ max_sequence_length=max_sequence_length,
308
+ device=device,
309
+ dtype=dtype,
310
+ )
311
+
312
+ return prompt_embeds, negative_prompt_embeds
313
+
314
+ def prepare_latents(
315
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
316
+ ):
317
+ if isinstance(generator, list) and len(generator) != batch_size:
318
+ raise ValueError(
319
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
320
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
321
+ )
322
+
323
+ shape = (
324
+ batch_size,
325
+ num_channels_latents,
326
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
327
+ height // self.vae.spatial_compression_ratio,
328
+ width // self.vae.spatial_compression_ratio,
329
+ )
330
+
331
+ if latents is None:
332
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
333
+ else:
334
+ latents = latents.to(device)
335
+
336
+ # scale the initial noise by the standard deviation required by the scheduler
337
+ if hasattr(self.scheduler, "init_noise_sigma"):
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
340
+
341
+ def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
342
+ vae = self.vae if vae is None else vae
343
+ weight_dtype = frames.dtype
344
+ if ref_images is None:
345
+ ref_images = [None] * len(frames)
346
+ else:
347
+ assert len(frames) == len(ref_images)
348
+
349
+ if masks is None:
350
+ latents = vae.encode(frames)[0].mode()
351
+ else:
352
+ masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks]
353
+ inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
354
+ reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
355
+ inactive = vae.encode(inactive)[0].mode()
356
+ reactive = vae.encode(reactive)[0].mode()
357
+ latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
358
+
359
+ cat_latents = []
360
+ for latent, refs in zip(latents, ref_images):
361
+ if refs is not None:
362
+ if masks is None:
363
+ ref_latent = vae.encode(refs)[0].mode()
364
+ else:
365
+ ref_latent = vae.encode(refs)[0].mode()
366
+ ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
367
+ assert all([x.shape[1] == 1 for x in ref_latent])
368
+ latent = torch.cat([*ref_latent, latent], dim=1)
369
+ cat_latents.append(latent)
370
+ return cat_latents
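For orientation, the channel and frame bookkeeping performed by `vace_encode_frames` (the concrete sizes below are illustrative assumptions):

# With masks, each per-sample latent is cat(inactive, reactive) along the
# channel axis, doubling C; each reference image adds one latent "frame" whose
# reactive half is zeroed and which is prepended along the temporal axis.
# For C = 16 latent channels, T latent frames and R reference images, every
# entry of cat_latents therefore has shape (2 * 16, R + T, H', W').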
371
+
372
+ def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]):
373
+ if ref_images is None:
374
+ ref_images = [None] * len(masks)
375
+ else:
376
+ assert len(masks) == len(ref_images)
377
+
378
+ result_masks = []
379
+ for mask, refs in zip(masks, ref_images):
380
+ c, depth, height, width = mask.shape
381
+ new_depth = int((depth + 3) // vae_stride[0])
382
+ height = 2 * (int(height) // (vae_stride[1] * 2))
383
+ width = 2 * (int(width) // (vae_stride[2] * 2))
384
+
385
+ # reshape
386
+ mask = mask[0, :, :, :]
387
+ mask = mask.view(
388
+ depth, height, vae_stride[1], width, vae_stride[1]
389
+ ) # depth, height, 8, width, 8
390
+ mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
391
+ mask = mask.reshape(
392
+ vae_stride[1] * vae_stride[2], depth, height, width
393
+ ) # 8*8, depth, height, width
394
+
395
+ # interpolation
396
+ mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
397
+
398
+ if refs is not None:
399
+ length = len(refs)
400
+ mask_pad = torch.zeros_like(mask[:, :length, :, :])
401
+ mask = torch.cat((mask_pad, mask), dim=1)
402
+ result_masks.append(mask)
403
+ return result_masks
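A shape walk-through of the mask encoding above, using assumed sizes (49 frames at 480x720 with `vae_stride = [4, 8, 8]`):

new_depth = (49 + 3) // 4      # 13 latent frames
lat_h = 2 * (480 // (8 * 2))   # 60
lat_w = 2 * (720 // (8 * 2))   # 90
# The 8x8 spatial neighbourhood is folded into the channel axis, so each
# returned mask has shape (64, 13, 60, 90); with one reference image a zero
# frame is prepended, giving (64, 14, 60, 90) to line up with the latents.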
404
+
405
+ def vace_latent(self, z, m):
406
+ return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
407
+
408
+ def prepare_control_latents(
409
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
410
+ ):
411
+ # resize the control to latents shape as we concatenate the control to the latents
412
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
413
+ # and half precision
414
+
415
+ if control is not None:
416
+ control = control.to(device=device, dtype=dtype)
417
+ bs = 1
418
+ new_control = []
419
+ for i in range(0, control.shape[0], bs):
420
+ control_bs = control[i : i + bs]
421
+ control_bs = self.vae.encode(control_bs)[0]
422
+ control_bs = control_bs.mode()
423
+ new_control.append(control_bs)
424
+ control = torch.cat(new_control, dim = 0)
425
+
426
+ if control_image is not None:
427
+ control_image = control_image.to(device=device, dtype=dtype)
428
+ bs = 1
429
+ new_control_pixel_values = []
430
+ for i in range(0, control_image.shape[0], bs):
431
+ control_pixel_values_bs = control_image[i : i + bs]
432
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
433
+ control_pixel_values_bs = control_pixel_values_bs.mode()
434
+ new_control_pixel_values.append(control_pixel_values_bs)
435
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
436
+ else:
437
+ control_image_latents = None
438
+
439
+ return control, control_image_latents
440
+
441
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
442
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
443
+ frames = (frames / 2 + 0.5).clamp(0, 1)
444
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
445
+ frames = frames.cpu().float().numpy()
446
+ return frames
447
+
448
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
449
+ def prepare_extra_step_kwargs(self, generator, eta):
450
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
451
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
452
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
453
+ # and should be between [0, 1]
454
+
455
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
456
+ extra_step_kwargs = {}
457
+ if accepts_eta:
458
+ extra_step_kwargs["eta"] = eta
459
+
460
+ # check if the scheduler accepts generator
461
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
462
+ if accepts_generator:
463
+ extra_step_kwargs["generator"] = generator
464
+ return extra_step_kwargs
465
+
466
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
467
+ def check_inputs(
468
+ self,
469
+ prompt,
470
+ height,
471
+ width,
472
+ negative_prompt,
473
+ callback_on_step_end_tensor_inputs,
474
+ prompt_embeds=None,
475
+ negative_prompt_embeds=None,
476
+ ):
477
+ if height % 8 != 0 or width % 8 != 0:
478
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
479
+
480
+ if callback_on_step_end_tensor_inputs is not None and not all(
481
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
482
+ ):
483
+ raise ValueError(
484
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
485
+ )
486
+ if prompt is not None and prompt_embeds is not None:
487
+ raise ValueError(
488
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
489
+ " only forward one of the two."
490
+ )
491
+ elif prompt is None and prompt_embeds is None:
492
+ raise ValueError(
493
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
494
+ )
495
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
496
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
497
+
498
+ if prompt is not None and negative_prompt_embeds is not None:
499
+ raise ValueError(
500
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
501
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
502
+ )
503
+
504
+ if negative_prompt is not None and negative_prompt_embeds is not None:
505
+ raise ValueError(
506
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
507
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
508
+ )
509
+
510
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
511
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
512
+ raise ValueError(
513
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
514
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
515
+ f" {negative_prompt_embeds.shape}."
516
+ )
517
+
518
+ @property
519
+ def guidance_scale(self):
520
+ return self._guidance_scale
521
+
522
+ @property
523
+ def num_timesteps(self):
524
+ return self._num_timesteps
525
+
526
+ @property
527
+ def attention_kwargs(self):
528
+ return self._attention_kwargs
529
+
530
+ @property
531
+ def interrupt(self):
532
+ return self._interrupt
533
+
534
+ @torch.no_grad()
535
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
536
+ def __call__(
537
+ self,
538
+ prompt: Optional[Union[str, List[str]]] = None,
539
+ negative_prompt: Optional[Union[str, List[str]]] = None,
540
+ height: int = 480,
541
+ width: int = 720,
542
+ video: Union[torch.FloatTensor] = None,
543
+ mask_video: Union[torch.FloatTensor] = None,
544
+ control_video: Union[torch.FloatTensor] = None,
545
+ subject_ref_images: Union[torch.FloatTensor] = None,
546
+ num_frames: int = 49,
547
+ num_inference_steps: int = 50,
548
+ timesteps: Optional[List[int]] = None,
549
+ guidance_scale: float = 6,
550
+ num_videos_per_prompt: int = 1,
551
+ eta: float = 0.0,
552
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
553
+ latents: Optional[torch.FloatTensor] = None,
554
+ prompt_embeds: Optional[torch.FloatTensor] = None,
555
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
556
+ output_type: str = "numpy",
557
+ return_dict: bool = False,
558
+ callback_on_step_end: Optional[
559
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
560
+ ] = None,
561
+ attention_kwargs: Optional[Dict[str, Any]] = None,
562
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
563
+ max_sequence_length: int = 512,
564
+ comfyui_progressbar: bool = False,
565
+ shift: int = 5,
566
+ vace_context_scale: float = 1.0
567
+ ) -> Union[WanPipelineOutput, Tuple]:
568
+ """
569
+ Function invoked when calling the pipeline for generation.
570
+ Args:
571
+
572
+ Examples:
573
+
574
+ Returns:
575
+
576
+ """
577
+
578
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
579
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
580
+ num_videos_per_prompt = 1
581
+
582
+ # 1. Check inputs. Raise error if not correct
583
+ self.check_inputs(
584
+ prompt,
585
+ height,
586
+ width,
587
+ negative_prompt,
588
+ callback_on_step_end_tensor_inputs,
589
+ prompt_embeds,
590
+ negative_prompt_embeds,
591
+ )
592
+ self._guidance_scale = guidance_scale
593
+ self._attention_kwargs = attention_kwargs
594
+ self._interrupt = False
595
+
596
+ # 2. Default call parameters
597
+ if prompt is not None and isinstance(prompt, str):
598
+ batch_size = 1
599
+ elif prompt is not None and isinstance(prompt, list):
600
+ batch_size = len(prompt)
601
+ else:
602
+ batch_size = prompt_embeds.shape[0]
603
+
604
+ device = self._execution_device
605
+ weight_dtype = self.text_encoder.dtype
606
+
607
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
608
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
609
+ # corresponds to doing no classifier free guidance.
610
+ do_classifier_free_guidance = guidance_scale > 1.0
611
+
612
+ # 3. Encode input prompt
613
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
614
+ prompt,
615
+ negative_prompt,
616
+ do_classifier_free_guidance,
617
+ num_videos_per_prompt=num_videos_per_prompt,
618
+ prompt_embeds=prompt_embeds,
619
+ negative_prompt_embeds=negative_prompt_embeds,
620
+ max_sequence_length=max_sequence_length,
621
+ device=device,
622
+ )
623
+ if do_classifier_free_guidance:
624
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
625
+ else:
626
+ in_prompt_embeds = prompt_embeds
627
+
628
+ # 4. Prepare timesteps
629
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
630
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
631
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
632
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
633
+ timesteps = self.scheduler.timesteps
634
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
635
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
636
+ timesteps, _ = retrieve_timesteps(
637
+ self.scheduler,
638
+ device=device,
639
+ sigmas=sampling_sigmas)
640
+ else:
641
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
642
+ self._num_timesteps = len(timesteps)
643
+ if comfyui_progressbar:
644
+ from comfy.utils import ProgressBar
645
+ pbar = ProgressBar(num_inference_steps + 2)
646
+
647
+ latent_channels = self.vae.config.latent_channels
648
+
649
+ if comfyui_progressbar:
650
+ pbar.update(1)
651
+
652
+ # Prepare mask latent variables
653
+ if mask_video is not None:
654
+ bs, _, video_length, height, width = video.size()
655
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
656
+ mask_condition = mask_condition.to(dtype=torch.float32)
657
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
658
+ mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device)
659
+
660
+
661
+ if control_video is not None:
662
+ video_length = control_video.shape[2]
663
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
664
+ control_video = control_video.to(dtype=torch.float32)
665
+ input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
666
+
667
+ input_video = input_video.to(dtype=weight_dtype, device=device)
668
+
669
+ elif video is not None:
670
+ video_length = video.shape[2]
671
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
672
+ init_video = init_video.to(dtype=torch.float32)
673
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device)
674
+
675
+ input_video = init_video * (mask_condition < 0.5)
676
+ input_video = input_video.to(dtype=weight_dtype, device=device)
677
+
678
+ if subject_ref_images is not None:
679
+ video_length = subject_ref_images.shape[2]
680
+ subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
681
+ subject_ref_images = subject_ref_images.to(dtype=torch.float32)
682
+ subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
683
+ subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device)
684
+
685
+ bs, c, f, h, w = subject_ref_images.size()
686
+ new_subject_ref_images = []
687
+ for i in range(bs):
688
+ new_subject_ref_images.append([])
689
+ for j in range(f):
690
+ new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1])
691
+ subject_ref_images = new_subject_ref_images
692
+
693
+ vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae)
694
+ mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images)
695
+ vace_context = self.vace_latent(vace_latents, mask_latents)
696
+
697
+ # 5. Prepare latents.
698
+ latents = self.prepare_latents(
699
+ batch_size * num_videos_per_prompt,
700
+ latent_channels,
701
+ num_frames,
702
+ height,
703
+ width,
704
+ weight_dtype,
705
+ device,
706
+ generator,
707
+ latents,
708
+ num_length_latents=vace_latents[0].size(1)
709
+ )
710
+
711
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
712
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
713
+
714
+ target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3))
715
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
716
+ # 7. Denoising loop
717
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
718
+ self.transformer.num_inference_steps = num_inference_steps
719
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
720
+ for i, t in enumerate(timesteps):
721
+ self.transformer.current_steps = i
722
+
723
+ if self.interrupt:
724
+ continue
725
+
726
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
727
+ if hasattr(self.scheduler, "scale_model_input"):
728
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
729
+
730
+ vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context
731
+
732
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
733
+ timestep = t.expand(latent_model_input.shape[0])
734
+
735
+ # predict noise model_output
736
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
737
+ noise_pred = self.transformer(
738
+ x=latent_model_input,
739
+ context=in_prompt_embeds,
740
+ t=timestep,
741
+ vace_context=vace_context_input,
742
+ seq_len=seq_len,
743
+ vace_context_scale=vace_context_scale
744
+ )
745
+
746
+ # perform guidance
747
+ if do_classifier_free_guidance:
748
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
749
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
750
+
751
+ # compute the previous noisy sample x_t -> x_t-1
752
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
753
+
754
+ if callback_on_step_end is not None:
755
+ callback_kwargs = {}
756
+ for k in callback_on_step_end_tensor_inputs:
757
+ callback_kwargs[k] = locals()[k]
758
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
759
+
760
+ latents = callback_outputs.pop("latents", latents)
761
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
762
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
763
+
764
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
765
+ progress_bar.update()
766
+ if comfyui_progressbar:
767
+ pbar.update(1)
768
+
769
+ if subject_ref_images is not None:
770
+ len_subject_ref_images = len(subject_ref_images[0])
771
+ latents = latents[:, :, len_subject_ref_images:, :, :]
772
+
773
+ if output_type == "numpy":
774
+ video = self.decode_latents(latents)
775
+ elif not output_type == "latent":
776
+ video = self.decode_latents(latents)
777
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
778
+ else:
779
+ video = latents
780
+
781
+ # Offload all models
782
+ self.maybe_free_model_hooks()
783
+
784
+ if not return_dict:
785
+ video = torch.from_numpy(video)
786
+
787
+ return WanPipelineOutput(videos=video)
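A minimal call sketch for this pipeline, assuming `pipe` is an already-assembled `WanVacePipeline` (tokenizer, text encoder, VAE, transformer, scheduler) on GPU; the tensor layouts, value ranges and prompt below are illustrative assumptions, not repository defaults:

import torch

num_frames, height, width = 49, 480, 720
video = torch.zeros(1, 3, num_frames, height, width)       # source clip, values in [0, 1] (assumed layout)
mask_video = torch.ones(1, 1, num_frames, height, width)   # 1 marks the region to regenerate

output = pipe(
    prompt="a corgi running on the beach at sunset",
    negative_prompt="blurry, low quality",
    height=height,
    width=width,
    video=video,
    mask_video=mask_video,
    num_frames=num_frames,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(42),
)
frames = output.videos   # decoded frames of shape (B, C, F, H, W), values in [0, 1]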
videox_fun/pipeline/pipeline_z_image.py ADDED
@@ -0,0 +1,613 @@
1
+ # Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import numpy as np
17
+ import PIL
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import FromSingleFileMixin
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
26
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
27
+ replace_example_docstring)
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from transformers import AutoTokenizer, PreTrainedModel
30
+
31
+ from ..models import AutoencoderKL, ZImageTransformer2DModel
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+ EXAMPLE_DOC_STRING = """
36
+ Examples:
37
+ ```py
38
+ >>> import torch
39
+ >>> from diffusers import ZImagePipeline
40
+
41
+ >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
42
+ >>> pipe.to("cuda")
43
+
44
+ >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
45
+ >>> # (1) Use flash attention 2
46
+ >>> # pipe.transformer.set_attention_backend("flash")
47
+ >>> # (2) Use flash attention 3
48
+ >>> # pipe.transformer.set_attention_backend("_flash_3")
49
+
50
+ >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
51
+ >>> image = pipe(
52
+ ...     prompt,
+ ...     height=1024,
+ ...     width=1024,
+ ...     num_inference_steps=9,
+ ...     guidance_scale=0.0,
+ ...     generator=torch.Generator("cuda").manual_seed(42),
+ ... ).images[0]
59
+ >>> image.save("zimage.png")
60
+ ```
61
+ """
62
+
63
+
64
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
65
+ def calculate_shift(
66
+ image_seq_len,
67
+ base_seq_len: int = 256,
68
+ max_seq_len: int = 4096,
69
+ base_shift: float = 0.5,
70
+ max_shift: float = 1.15,
71
+ ):
72
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
73
+ b = base_shift - m * base_seq_len
74
+ mu = image_seq_len * m + b
75
+ return mu
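A worked example of the shift formula with the default constants (the 1024-token sequence length is an assumption):

# m = (1.15 - 0.5) / (4096 - 256) ≈ 1.69e-4, b = 0.5 - m * 256 ≈ 0.457
mu = calculate_shift(image_seq_len=1024)
print(round(mu, 2))   # 0.63: the shift grows linearly from 0.5 toward 1.15 as the token count approaches 4096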
76
+
77
+
78
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
79
+ def retrieve_timesteps(
80
+ scheduler,
81
+ num_inference_steps: Optional[int] = None,
82
+ device: Optional[Union[str, torch.device]] = None,
83
+ timesteps: Optional[List[int]] = None,
84
+ sigmas: Optional[List[float]] = None,
85
+ **kwargs,
86
+ ):
87
+ r"""
88
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
89
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
90
+
91
+ Args:
92
+ scheduler (`SchedulerMixin`):
93
+ The scheduler to get timesteps from.
94
+ num_inference_steps (`int`):
95
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
96
+ must be `None`.
97
+ device (`str` or `torch.device`, *optional*):
98
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
99
+ timesteps (`List[int]`, *optional*):
100
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
101
+ `num_inference_steps` and `sigmas` must be `None`.
102
+ sigmas (`List[float]`, *optional*):
103
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
104
+ `num_inference_steps` and `timesteps` must be `None`.
105
+
106
+ Returns:
107
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
108
+ second element is the number of inference steps.
109
+ """
110
+ if timesteps is not None and sigmas is not None:
111
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
112
+ if timesteps is not None:
113
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
114
+ if not accepts_timesteps:
115
+ raise ValueError(
116
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
117
+ f" timestep schedules. Please check whether you are using the correct scheduler."
118
+ )
119
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
120
+ timesteps = scheduler.timesteps
121
+ num_inference_steps = len(timesteps)
122
+ elif sigmas is not None:
123
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
124
+ if not accept_sigmas:
125
+ raise ValueError(
126
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
127
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
128
+ )
129
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
130
+ timesteps = scheduler.timesteps
131
+ num_inference_steps = len(timesteps)
132
+ else:
133
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
134
+ timesteps = scheduler.timesteps
135
+ return timesteps, num_inference_steps
136
+
137
+
138
+ @dataclass
139
+ class ZImagePipelineOutput(BaseOutput):
140
+ """
141
+ Output class for Z-Image image generation pipelines.
142
+
143
+ Args:
144
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
145
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
146
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
147
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
148
+ passed to the decoder.
149
+ """
150
+
151
+ images: Union[List[PIL.Image.Image], np.ndarray]
152
+
153
+
154
+ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
155
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
156
+ _optional_components = []
157
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
158
+
159
+ def __init__(
160
+ self,
161
+ scheduler: FlowMatchEulerDiscreteScheduler,
162
+ vae: AutoencoderKL,
163
+ text_encoder: PreTrainedModel,
164
+ tokenizer: AutoTokenizer,
165
+ transformer: ZImageTransformer2DModel,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ tokenizer=tokenizer,
173
+ scheduler=scheduler,
174
+ transformer=transformer,
175
+ )
176
+ self.vae_scale_factor = (
177
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
178
+ )
179
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
180
+
181
+ def encode_prompt(
182
+ self,
183
+ prompt: Union[str, List[str]],
184
+ device: Optional[torch.device] = None,
185
+ do_classifier_free_guidance: bool = True,
186
+ negative_prompt: Optional[Union[str, List[str]]] = None,
187
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
188
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
189
+ max_sequence_length: int = 512,
190
+ ):
191
+ prompt = [prompt] if isinstance(prompt, str) else prompt
192
+ prompt_embeds = self._encode_prompt(
193
+ prompt=prompt,
194
+ device=device,
195
+ prompt_embeds=prompt_embeds,
196
+ max_sequence_length=max_sequence_length,
197
+ )
198
+
199
+ if do_classifier_free_guidance:
200
+ if negative_prompt is None:
201
+ negative_prompt = ["" for _ in prompt]
202
+ else:
203
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
204
+ assert len(prompt) == len(negative_prompt)
205
+ negative_prompt_embeds = self._encode_prompt(
206
+ prompt=negative_prompt,
207
+ device=device,
208
+ prompt_embeds=negative_prompt_embeds,
209
+ max_sequence_length=max_sequence_length,
210
+ )
211
+ else:
212
+ negative_prompt_embeds = []
213
+ return prompt_embeds, negative_prompt_embeds
214
+
215
+ def _encode_prompt(
216
+ self,
217
+ prompt: Union[str, List[str]],
218
+ device: Optional[torch.device] = None,
219
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
220
+ max_sequence_length: int = 512,
221
+ ) -> List[torch.FloatTensor]:
222
+ device = device or self._execution_device
223
+
224
+ if prompt_embeds is not None:
225
+ return prompt_embeds
226
+
227
+ if isinstance(prompt, str):
228
+ prompt = [prompt]
229
+
230
+ for i, prompt_item in enumerate(prompt):
231
+ messages = [
232
+ {"role": "user", "content": prompt_item},
233
+ ]
234
+ prompt_item = self.tokenizer.apply_chat_template(
235
+ messages,
236
+ tokenize=False,
237
+ add_generation_prompt=True,
238
+ enable_thinking=True,
239
+ )
240
+ prompt[i] = prompt_item
241
+
242
+ text_inputs = self.tokenizer(
243
+ prompt,
244
+ padding="max_length",
245
+ max_length=max_sequence_length,
246
+ truncation=True,
247
+ return_tensors="pt",
248
+ )
249
+
250
+ text_input_ids = text_inputs.input_ids.to(device)
251
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
252
+
253
+ prompt_embeds = self.text_encoder(
254
+ input_ids=text_input_ids,
255
+ attention_mask=prompt_masks,
256
+ output_hidden_states=True,
257
+ ).hidden_states[-2]
258
+
259
+ embeddings_list = []
260
+
261
+ for i in range(len(prompt_embeds)):
262
+ embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
263
+
264
+ return embeddings_list
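Note that, unlike most pipelines that return a padded batch, this returns a list of per-prompt tensors trimmed with the attention mask:

# embeddings_list[i] == prompt_embeds[i][prompt_masks[i]], so entry i has shape
# (num_valid_tokens_i, hidden_size) and different prompts may have different
# lengths; hidden_size depends on the text encoder and is not fixed here.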
265
+
266
+ def prepare_latents(
267
+ self,
268
+ batch_size,
269
+ num_channels_latents,
270
+ height,
271
+ width,
272
+ dtype,
273
+ device,
274
+ generator,
275
+ latents=None,
276
+ ):
277
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
278
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
279
+
280
+ shape = (batch_size, num_channels_latents, height, width)
281
+
282
+ if latents is None:
283
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
284
+ else:
285
+ if latents.shape != shape:
286
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
287
+ latents = latents.to(device)
288
+ return latents
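A shape sketch for `prepare_latents`, assuming an 8x VAE downsampling factor (`vae_scale_factor == 8`) and a 1024x1024 request:

lat_h = 2 * (1024 // (8 * 2))   # 128
lat_w = 2 * (1024 // (8 * 2))   # 128
# latents: (batch_size, num_channels_latents, 128, 128); the 2x2 patching used
# later gives image_seq_len = (128 // 2) * (128 // 2) = 4096 tokens.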
289
+
290
+ @property
291
+ def guidance_scale(self):
292
+ return self._guidance_scale
293
+
294
+ @property
295
+ def do_classifier_free_guidance(self):
296
+ return self._guidance_scale > 1
297
+
298
+ @property
299
+ def joint_attention_kwargs(self):
300
+ return self._joint_attention_kwargs
301
+
302
+ @property
303
+ def num_timesteps(self):
304
+ return self._num_timesteps
305
+
306
+ @property
307
+ def interrupt(self):
308
+ return self._interrupt
309
+
310
+ @torch.no_grad()
311
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
312
+ def __call__(
313
+ self,
314
+ prompt: Union[str, List[str]] = None,
315
+ height: Optional[int] = None,
316
+ width: Optional[int] = None,
317
+ num_inference_steps: int = 50,
318
+ sigmas: Optional[List[float]] = None,
319
+ guidance_scale: float = 5.0,
320
+ cfg_normalization: bool = False,
321
+ cfg_truncation: float = 1.0,
322
+ negative_prompt: Optional[Union[str, List[str]]] = None,
323
+ num_images_per_prompt: Optional[int] = 1,
324
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
325
+ latents: Optional[torch.FloatTensor] = None,
326
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
327
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
328
+ output_type: Optional[str] = "pil",
329
+ return_dict: bool = True,
330
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
331
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
332
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
333
+ max_sequence_length: int = 512,
334
+ ):
335
+ r"""
336
+ Function invoked when calling the pipeline for generation.
337
+
338
+ Args:
339
+ prompt (`str` or `List[str]`, *optional*):
340
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
341
+ instead.
342
+ height (`int`, *optional*, defaults to 1024):
343
+ The height in pixels of the generated image.
344
+ width (`int`, *optional*, defaults to 1024):
345
+ The width in pixels of the generated image.
346
+ num_inference_steps (`int`, *optional*, defaults to 50):
347
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
348
+ expense of slower inference.
349
+ sigmas (`List[float]`, *optional*):
350
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
351
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
352
+ will be used.
353
+ guidance_scale (`float`, *optional*, defaults to 5.0):
354
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
355
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
356
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
357
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
358
+ usually at the expense of lower image quality.
359
+ cfg_normalization (`bool`, *optional*, defaults to False):
360
+ Whether to apply configuration normalization.
361
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
362
+ The truncation value for configuration.
363
+ negative_prompt (`str` or `List[str]`, *optional*):
364
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
365
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
366
+ less than `1`).
367
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
368
+ The number of images to generate per prompt.
369
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
370
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
371
+ to make generation deterministic.
372
+ latents (`torch.FloatTensor`, *optional*):
373
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
374
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
375
+ tensor will be generated by sampling using the supplied random `generator`.
376
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
377
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
378
+ provided, text embeddings will be generated from `prompt` input argument.
379
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
380
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
381
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
382
+ argument.
383
+ output_type (`str`, *optional*, defaults to `"pil"`):
384
+ The output format of the generate image. Choose between
385
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
386
+ return_dict (`bool`, *optional*, defaults to `True`):
387
+ Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
388
+ tuple.
389
+ joint_attention_kwargs (`dict`, *optional*):
390
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
391
+ `self.processor` in
392
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
393
+ callback_on_step_end (`Callable`, *optional*):
394
+ A function that calls at the end of each denoising steps during the inference. The function is called
395
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
396
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
397
+ `callback_on_step_end_tensor_inputs`.
398
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
399
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
400
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
401
+ `._callback_tensor_inputs` attribute of your pipeline class.
402
+ max_sequence_length (`int`, *optional*, defaults to 512):
403
+ Maximum sequence length to use with the `prompt`.
404
+
405
+ Examples:
406
+
407
+ Returns:
408
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
409
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
410
+ generated images.
411
+ """
412
+ height = height or 1024
413
+ width = width or 1024
414
+
415
+ vae_scale = self.vae_scale_factor * 2
416
+ if height % vae_scale != 0:
417
+ raise ValueError(
418
+ f"Height must be divisible by {vae_scale} (got {height}). "
419
+ f"Please adjust the height to a multiple of {vae_scale}."
420
+ )
421
+ if width % vae_scale != 0:
422
+ raise ValueError(
423
+ f"Width must be divisible by {vae_scale} (got {width}). "
424
+ f"Please adjust the width to a multiple of {vae_scale}."
425
+ )
426
+
427
+ device = self._execution_device
428
+
429
+ self._guidance_scale = guidance_scale
430
+ self._joint_attention_kwargs = joint_attention_kwargs
431
+ self._interrupt = False
432
+ self._cfg_normalization = cfg_normalization
433
+ self._cfg_truncation = cfg_truncation
434
+ # 2. Define call parameters
435
+ if prompt is not None and isinstance(prompt, str):
436
+ batch_size = 1
437
+ elif prompt is not None and isinstance(prompt, list):
438
+ batch_size = len(prompt)
439
+ else:
440
+ batch_size = len(prompt_embeds)
441
+
442
+ # If prompt_embeds is provided and prompt is None, skip encoding
443
+ if prompt_embeds is not None and prompt is None:
444
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
445
+ raise ValueError(
446
+ "When `prompt_embeds` is provided without `prompt`, "
447
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
448
+ )
449
+ else:
450
+ (
451
+ prompt_embeds,
452
+ negative_prompt_embeds,
453
+ ) = self.encode_prompt(
454
+ prompt=prompt,
455
+ negative_prompt=negative_prompt,
456
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
457
+ prompt_embeds=prompt_embeds,
458
+ negative_prompt_embeds=negative_prompt_embeds,
459
+ device=device,
460
+ max_sequence_length=max_sequence_length,
461
+ )
462
+
463
+ # 4. Prepare latent variables
464
+ num_channels_latents = self.transformer.in_channels
465
+
466
+ latents = self.prepare_latents(
467
+ batch_size * num_images_per_prompt,
468
+ num_channels_latents,
469
+ height,
470
+ width,
471
+ torch.float32,
472
+ device,
473
+ generator,
474
+ latents,
475
+ )
476
+
477
+ # Repeat prompt_embeds for num_images_per_prompt
478
+ if num_images_per_prompt > 1:
479
+ prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
480
+ if self.do_classifier_free_guidance and negative_prompt_embeds:
481
+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
482
+
483
+ actual_batch_size = batch_size * num_images_per_prompt
484
+ image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
485
+
486
+ # 5. Prepare timesteps
487
+ mu = calculate_shift(
488
+ image_seq_len,
489
+ self.scheduler.config.get("base_image_seq_len", 256),
490
+ self.scheduler.config.get("max_image_seq_len", 4096),
491
+ self.scheduler.config.get("base_shift", 0.5),
492
+ self.scheduler.config.get("max_shift", 1.15),
493
+ )
494
+ self.scheduler.sigma_min = 0.0
495
+ scheduler_kwargs = {"mu": mu}
496
+ timesteps, num_inference_steps = retrieve_timesteps(
497
+ self.scheduler,
498
+ num_inference_steps,
499
+ device,
500
+ sigmas=sigmas,
501
+ **scheduler_kwargs,
502
+ )
503
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
504
+ self._num_timesteps = len(timesteps)
505
+
506
+ # 6. Denoising loop
507
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
508
+ for i, t in enumerate(timesteps):
509
+ if self.interrupt:
510
+ continue
511
+
512
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
513
+ timestep = t.expand(latents.shape[0])
514
+ timestep = (1000 - timestep) / 1000
515
+ # Normalized time for time-aware config (0 at start, 1 at end)
516
+ t_norm = timestep[0].item()
517
+
518
+ # Handle cfg truncation
519
+ current_guidance_scale = self.guidance_scale
520
+ if (
521
+ self.do_classifier_free_guidance
522
+ and self._cfg_truncation is not None
523
+ and float(self._cfg_truncation) <= 1
524
+ ):
525
+ if t_norm > self._cfg_truncation:
526
+ current_guidance_scale = 0.0
527
+
528
+ # Run CFG only if configured AND scale is non-zero
529
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
530
+
531
+ if apply_cfg:
532
+ latents_typed = latents.to(self.transformer.dtype)
533
+ latent_model_input = latents_typed.repeat(2, 1, 1, 1)
534
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
535
+ timestep_model_input = timestep.repeat(2)
536
+ else:
537
+ latent_model_input = latents.to(self.transformer.dtype)
538
+ prompt_embeds_model_input = prompt_embeds
539
+ timestep_model_input = timestep
540
+
541
+ latent_model_input = latent_model_input.unsqueeze(2)
542
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
543
+
544
+ model_out_list = self.transformer(
545
+ latent_model_input_list,
546
+ timestep_model_input,
547
+ prompt_embeds_model_input,
548
+ )[0]
549
+
550
+ if apply_cfg:
551
+ # Perform CFG
552
+ pos_out = model_out_list[:actual_batch_size]
553
+ neg_out = model_out_list[actual_batch_size:]
554
+
555
+ noise_pred = []
556
+ for j in range(actual_batch_size):
557
+ pos = pos_out[j].float()
558
+ neg = neg_out[j].float()
559
+
560
+ pred = pos + current_guidance_scale * (pos - neg)
561
+
562
+ # Renormalization
563
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
564
+ ori_pos_norm = torch.linalg.vector_norm(pos)
565
+ new_pos_norm = torch.linalg.vector_norm(pred)
566
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
567
+ if new_pos_norm > max_new_norm:
568
+ pred = pred * (max_new_norm / new_pos_norm)
569
+
570
+ noise_pred.append(pred)
571
+
572
+ noise_pred = torch.stack(noise_pred, dim=0)
573
+ else:
574
+ noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)
575
+
576
+ noise_pred = noise_pred.squeeze(2)
577
+ noise_pred = -noise_pred
578
+
579
+ # compute the previous noisy sample x_t -> x_t-1
580
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
581
+ assert latents.dtype == torch.float32
582
+
583
+ if callback_on_step_end is not None:
584
+ callback_kwargs = {}
585
+ for k in callback_on_step_end_tensor_inputs:
586
+ callback_kwargs[k] = locals()[k]
587
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
588
+
589
+ latents = callback_outputs.pop("latents", latents)
590
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
591
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
592
+
593
+ # call the callback, if provided
594
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
595
+ progress_bar.update()
596
+
597
+ if output_type == "latent":
598
+ image = latents
599
+
600
+ else:
601
+ latents = latents.to(self.vae.dtype)
602
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
603
+
604
+ image = self.vae.decode(latents, return_dict=False)[0]
605
+ image = self.image_processor.postprocess(image, output_type=output_type)
606
+
607
+ # Offload all models
608
+ self.maybe_free_model_hooks()
609
+
610
+ if not return_dict:
611
+ return (image,)
612
+
613
+ return ZImagePipelineOutput(images=image)
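A minimal call sketch for this repository's `ZImagePipeline`, assuming `pipe` has already been assembled from its components (scheduler, VAE, text encoder, tokenizer, transformer) and moved to CUDA; the prompt and seed are illustrative:

import torch

image = pipe(
    prompt="a watercolor fox in a snowy forest",
    height=1024,
    width=1024,
    num_inference_steps=9,
    guidance_scale=0.0,   # CFG is skipped entirely when the scale is <= 1
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("zimage_sketch.png")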
videox_fun/pipeline/pipeline_z_image_control.py ADDED
@@ -0,0 +1,633 @@
1
+ # Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import numpy as np
17
+ import PIL
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ from diffusers.loaders import FromSingleFileMixin
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
28
+ replace_example_docstring)
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from transformers import AutoTokenizer, PreTrainedModel
31
+
32
+ from ..models import AutoencoderKL, ZImageTransformer2DModel
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+ EXAMPLE_DOC_STRING = """
37
+ Examples:
38
+ ```py
39
+ >>> import torch
40
+ >>> from diffusers import ZImagePipeline
41
+
42
+ >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
43
+ >>> pipe.to("cuda")
44
+
45
+ >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
46
+ >>> # (1) Use flash attention 2
47
+ >>> # pipe.transformer.set_attention_backend("flash")
48
+ >>> # (2) Use flash attention 3
49
+ >>> # pipe.transformer.set_attention_backend("_flash_3")
50
+
51
+ >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
52
+ >>> image = pipe(
53
+ ...     prompt,
+ ...     height=1024,
+ ...     width=1024,
+ ...     num_inference_steps=9,
+ ...     guidance_scale=0.0,
+ ...     generator=torch.Generator("cuda").manual_seed(42),
+ ... ).images[0]
60
+ >>> image.save("zimage.png")
61
+ ```
62
+ """
63
+
64
+
65
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
66
+ def calculate_shift(
67
+ image_seq_len,
68
+ base_seq_len: int = 256,
69
+ max_seq_len: int = 4096,
70
+ base_shift: float = 0.5,
71
+ max_shift: float = 1.15,
72
+ ):
73
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
74
+ b = base_shift - m * base_seq_len
75
+ mu = image_seq_len * m + b
76
+ return mu
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
80
+ def retrieve_timesteps(
81
+ scheduler,
82
+ num_inference_steps: Optional[int] = None,
83
+ device: Optional[Union[str, torch.device]] = None,
84
+ timesteps: Optional[List[int]] = None,
85
+ sigmas: Optional[List[float]] = None,
86
+ **kwargs,
87
+ ):
88
+ r"""
89
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
90
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
91
+
92
+ Args:
93
+ scheduler (`SchedulerMixin`):
94
+ The scheduler to get timesteps from.
95
+ num_inference_steps (`int`):
96
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
97
+ must be `None`.
98
+ device (`str` or `torch.device`, *optional*):
99
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
100
+ timesteps (`List[int]`, *optional*):
101
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
102
+ `num_inference_steps` and `sigmas` must be `None`.
103
+ sigmas (`List[float]`, *optional*):
104
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
105
+ `num_inference_steps` and `timesteps` must be `None`.
106
+
107
+ Returns:
108
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
109
+ second element is the number of inference steps.
110
+ """
111
+ if timesteps is not None and sigmas is not None:
112
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
113
+ if timesteps is not None:
114
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115
+ if not accepts_timesteps:
116
+ raise ValueError(
117
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118
+ f" timestep schedules. Please check whether you are using the correct scheduler."
119
+ )
120
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ num_inference_steps = len(timesteps)
123
+ elif sigmas is not None:
124
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
138
+
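A minimal standalone sketch of this helper, assuming a stock `FlowMatchEulerDiscreteScheduler` from a recent diffusers release (this mirrors how the pipeline calls it further below, minus the `mu` shift):

```py
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Let the scheduler build its own 9-step schedule ...
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=9, device="cpu")

# ... or override it with custom sigmas; the step count is then derived from their length.
timesteps, num_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=[1.0, 0.75, 0.5, 0.25, 0.0])
```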
139
+ @dataclass
140
+ class ZImagePipelineOutput(BaseOutput):
141
+ """
142
+ Output class for Z-Image image generation pipelines.
143
+
144
+ Args:
145
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
146
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
147
+ height, width, num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion
148
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
149
+ passed to the decoder.
150
+ """
151
+
152
+ images: Union[List[PIL.Image.Image], np.ndarray]
153
+
154
+
155
+ class ZImageControlPipeline(DiffusionPipeline, FromSingleFileMixin):
156
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
157
+ _optional_components = []
158
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
159
+
160
+ def __init__(
161
+ self,
162
+ scheduler: FlowMatchEulerDiscreteScheduler,
163
+ vae: AutoencoderKL,
164
+ text_encoder: PreTrainedModel,
165
+ tokenizer: AutoTokenizer,
166
+ transformer: ZImageTransformer2DModel,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.register_modules(
171
+ vae=vae,
172
+ text_encoder=text_encoder,
173
+ tokenizer=tokenizer,
174
+ scheduler=scheduler,
175
+ transformer=transformer,
176
+ )
177
+ self.vae_scale_factor = (
178
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
179
+ )
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
181
+ self.mask_processor = VaeImageProcessor(
182
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
183
+ )
184
+
185
+ def encode_prompt(
186
+ self,
187
+ prompt: Union[str, List[str]],
188
+ device: Optional[torch.device] = None,
189
+ do_classifier_free_guidance: bool = True,
190
+ negative_prompt: Optional[Union[str, List[str]]] = None,
191
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
192
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
193
+ max_sequence_length: int = 512,
194
+ ):
195
+ prompt = [prompt] if isinstance(prompt, str) else prompt
196
+ prompt_embeds = self._encode_prompt(
197
+ prompt=prompt,
198
+ device=device,
199
+ prompt_embeds=prompt_embeds,
200
+ max_sequence_length=max_sequence_length,
201
+ )
202
+
203
+ if do_classifier_free_guidance:
204
+ if negative_prompt is None:
205
+ negative_prompt = ["" for _ in prompt]
206
+ else:
207
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
208
+ assert len(prompt) == len(negative_prompt)
209
+ negative_prompt_embeds = self._encode_prompt(
210
+ prompt=negative_prompt,
211
+ device=device,
212
+ prompt_embeds=negative_prompt_embeds,
213
+ max_sequence_length=max_sequence_length,
214
+ )
215
+ else:
216
+ negative_prompt_embeds = []
217
+ return prompt_embeds, negative_prompt_embeds
218
+
219
+ def _encode_prompt(
220
+ self,
221
+ prompt: Union[str, List[str]],
222
+ device: Optional[torch.device] = None,
223
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
224
+ max_sequence_length: int = 512,
225
+ ) -> List[torch.FloatTensor]:
226
+ device = device or self._execution_device
227
+
228
+ if prompt_embeds is not None:
229
+ return prompt_embeds
230
+
231
+ if isinstance(prompt, str):
232
+ prompt = [prompt]
233
+
234
+ for i, prompt_item in enumerate(prompt):
235
+ messages = [
236
+ {"role": "user", "content": prompt_item},
237
+ ]
238
+ prompt_item = self.tokenizer.apply_chat_template(
239
+ messages,
240
+ tokenize=False,
241
+ add_generation_prompt=True,
242
+ enable_thinking=True,
243
+ )
244
+ prompt[i] = prompt_item
245
+
246
+ text_inputs = self.tokenizer(
247
+ prompt,
248
+ padding="max_length",
249
+ max_length=max_sequence_length,
250
+ truncation=True,
251
+ return_tensors="pt",
252
+ )
253
+
254
+ text_input_ids = text_inputs.input_ids.to(device)
255
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
256
+
257
+ prompt_embeds = self.text_encoder(
258
+ input_ids=text_input_ids,
259
+ attention_mask=prompt_masks,
260
+ output_hidden_states=True,
261
+ ).hidden_states[-2]
262
+
263
+ embeddings_list = []
264
+
265
+ for i in range(len(prompt_embeds)):
266
+ embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
267
+
268
+ return embeddings_list
269
+
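Note that `_encode_prompt` wraps each prompt in the tokenizer's chat template and then strips padding with the attention mask, so the pipeline works with a Python list of variable-length embedding tensors rather than a single padded batch. A hypothetical illustration (the pipeline instance and prompts are placeholders, not from this repository):

```py
# `pipe` is an already-constructed ZImageControlPipeline (see the usage sketch at the end of this file).
embeds, neg_embeds = pipe.encode_prompt(
    ["a photo of a cat", "a much longer prompt describing a rainy street at night"],
    device="cuda",
    do_classifier_free_guidance=False,
)
print([tuple(e.shape) for e in embeds])  # e.g. [(n_tokens_1, hidden), (n_tokens_2, hidden)], lengths differ
```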
270
+ def prepare_latents(
271
+ self,
272
+ batch_size,
273
+ num_channels_latents,
274
+ height,
275
+ width,
276
+ dtype,
277
+ device,
278
+ generator,
279
+ latents=None,
280
+ ):
281
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
282
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
283
+
284
+ shape = (batch_size, num_channels_latents, height, width)
285
+
286
+ if latents is None:
287
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
288
+ else:
289
+ if latents.shape != shape:
290
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
291
+ latents = latents.to(device)
292
+ return latents
293
+
294
+ @property
295
+ def guidance_scale(self):
296
+ return self._guidance_scale
297
+
298
+ @property
299
+ def do_classifier_free_guidance(self):
300
+ return self._guidance_scale > 1
301
+
302
+ @property
303
+ def joint_attention_kwargs(self):
304
+ return self._joint_attention_kwargs
305
+
306
+ @property
307
+ def num_timesteps(self):
308
+ return self._num_timesteps
309
+
310
+ @property
311
+ def interrupt(self):
312
+ return self._interrupt
313
+
314
+ @torch.no_grad()
315
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
316
+ def __call__(
317
+ self,
318
+ prompt: Union[str, List[str]] = None,
319
+ height: Optional[int] = None,
320
+ width: Optional[int] = None,
321
+
322
+ control_image: Optional[torch.FloatTensor] = None,
323
+ control_context_scale: float = 1.0,
324
+
325
+ num_inference_steps: int = 50,
326
+ sigmas: Optional[List[float]] = None,
327
+ guidance_scale: float = 5.0,
328
+ cfg_normalization: bool = False,
329
+ cfg_truncation: float = 1.0,
330
+ negative_prompt: Optional[Union[str, List[str]]] = None,
331
+ num_images_per_prompt: Optional[int] = 1,
332
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
333
+ latents: Optional[torch.FloatTensor] = None,
334
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
335
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
336
+ output_type: Optional[str] = "pil",
337
+ return_dict: bool = True,
338
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
339
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
340
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
341
+ max_sequence_length: int = 512,
342
+ ):
343
+ r"""
344
+ Function invoked when calling the pipeline for generation.
345
+
346
+ Args:
347
+ prompt (`str` or `List[str]`, *optional*):
348
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
349
+ instead.
350
+ height (`int`, *optional*, defaults to 1024):
351
+ The height in pixels of the generated image.
352
+ width (`int`, *optional*, defaults to 1024):
353
+ The width in pixels of the generated image.
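+ control_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+ A control image whose VAE-encoded latents are passed to the transformer as the control context. If
+ not provided, an all-zero control context is used.
+ control_context_scale (`float`, *optional*, defaults to 1.0):
+ Strength with which the control context is applied inside the transformer.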
354
+ num_inference_steps (`int`, *optional*, defaults to 50):
355
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
356
+ expense of slower inference.
357
+ sigmas (`List[float]`, *optional*):
358
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
359
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
360
+ will be used.
361
+ guidance_scale (`float`, *optional*, defaults to 5.0):
362
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
363
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
364
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
365
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
366
+ usually at the expense of lower image quality.
367
+ cfg_normalization (`bool`, *optional*, defaults to False):
368
+ Whether to rescale the classifier-free guided prediction so that its norm does not exceed that of the conditional prediction.
369
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
370
+ Normalized-time threshold above which classifier-free guidance is disabled; with the default of 1.0, guidance is applied over the whole sampling trajectory.
371
+ negative_prompt (`str` or `List[str]`, *optional*):
372
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
373
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
374
+ not greater than `1`).
375
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
376
+ The number of images to generate per prompt.
377
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
378
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
379
+ to make generation deterministic.
380
+ latents (`torch.FloatTensor`, *optional*):
381
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
382
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
383
+ tensor will be generated by sampling using the supplied random `generator`.
384
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
385
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
386
+ provided, text embeddings will be generated from `prompt` input argument.
387
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
388
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
389
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
390
+ argument.
391
+ output_type (`str`, *optional*, defaults to `"pil"`):
392
+ The output format of the generated image. Choose between
393
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
394
+ return_dict (`bool`, *optional*, defaults to `True`):
395
+ Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain
396
+ tuple.
397
+ joint_attention_kwargs (`dict`, *optional*):
398
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
399
+ `self.processor` in
400
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
401
+ callback_on_step_end (`Callable`, *optional*):
402
+ A function that is called at the end of each denoising step during inference. The function is called
403
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
404
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
405
+ `callback_on_step_end_tensor_inputs`.
406
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
407
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
408
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
409
+ `._callback_tensor_inputs` attribute of your pipeline class.
410
+ max_sequence_length (`int`, *optional*, defaults to 512):
411
+ Maximum sequence length to use with the `prompt`.
412
+
413
+ Examples:
414
+
415
+ Returns:
416
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
417
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
418
+ generated images.
419
+ """
420
+ height = height or 1024
421
+ width = width or 1024
422
+
423
+ vae_scale = self.vae_scale_factor * 2
424
+ if height % vae_scale != 0:
425
+ raise ValueError(
426
+ f"Height must be divisible by {vae_scale} (got {height}). "
427
+ f"Please adjust the height to a multiple of {vae_scale}."
428
+ )
429
+ if width % vae_scale != 0:
430
+ raise ValueError(
431
+ f"Width must be divisible by {vae_scale} (got {width}). "
432
+ f"Please adjust the width to a multiple of {vae_scale}."
433
+ )
434
+
435
+ self._guidance_scale = guidance_scale
436
+ self._joint_attention_kwargs = joint_attention_kwargs
437
+ self._interrupt = False
438
+ self._cfg_normalization = cfg_normalization
439
+ self._cfg_truncation = cfg_truncation
440
+ # 2. Define call parameters
441
+ if prompt is not None and isinstance(prompt, str):
442
+ batch_size = 1
443
+ elif prompt is not None and isinstance(prompt, list):
444
+ batch_size = len(prompt)
445
+ else:
446
+ batch_size = len(prompt_embeds)
447
+
448
+ device = self._execution_device
449
+ weight_dtype = self.text_encoder.dtype
450
+ num_channels_latents = self.transformer.in_channels
451
+
452
+ if control_image is not None:
453
+ control_image = self.image_processor.preprocess(control_image, height=height, width=width)
454
+ control_image = control_image.to(dtype=weight_dtype, device=device)
455
+ control_latents = self.vae.encode(control_image)[0].mode()
456
+ control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
457
+ else:
458
+ # `inpaint_latent` is undefined in this pipeline; fall back to an all-zero control
+ # latent matching the spatial/channel shape the VAE encoder would have produced.
+ latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ control_latents = torch.zeros(
+ (batch_size, num_channels_latents, latent_height, latent_width), dtype=weight_dtype, device=device
+ )
459
+
460
+ control_context = control_latents.unsqueeze(2)
461
+
462
+ # If prompt_embeds is provided and prompt is None, skip encoding
463
+ if prompt_embeds is not None and prompt is None:
464
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
465
+ raise ValueError(
466
+ "When `prompt_embeds` is provided without `prompt`, "
467
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
468
+ )
469
+ else:
470
+ (
471
+ prompt_embeds,
472
+ negative_prompt_embeds,
473
+ ) = self.encode_prompt(
474
+ prompt=prompt,
475
+ negative_prompt=negative_prompt,
476
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
477
+ prompt_embeds=prompt_embeds,
478
+ negative_prompt_embeds=negative_prompt_embeds,
479
+ device=device,
480
+ max_sequence_length=max_sequence_length,
481
+ )
482
+
483
+ # 4. Prepare latent variables
484
+ latents = self.prepare_latents(
485
+ batch_size * num_images_per_prompt,
486
+ num_channels_latents,
487
+ height,
488
+ width,
489
+ torch.float32,
490
+ device,
491
+ generator,
492
+ latents,
493
+ )
494
+
495
+ # Repeat prompt_embeds for num_images_per_prompt
496
+ if num_images_per_prompt > 1:
497
+ prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
498
+ if self.do_classifier_free_guidance and negative_prompt_embeds:
499
+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
500
+
501
+ actual_batch_size = batch_size * num_images_per_prompt
502
+ image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
503
+
504
+ # 5. Prepare timesteps
505
+ mu = calculate_shift(
506
+ image_seq_len,
507
+ self.scheduler.config.get("base_image_seq_len", 256),
508
+ self.scheduler.config.get("max_image_seq_len", 4096),
509
+ self.scheduler.config.get("base_shift", 0.5),
510
+ self.scheduler.config.get("max_shift", 1.15),
511
+ )
512
+ self.scheduler.sigma_min = 0.0
513
+ scheduler_kwargs = {"mu": mu}
514
+ timesteps, num_inference_steps = retrieve_timesteps(
515
+ self.scheduler,
516
+ num_inference_steps,
517
+ device,
518
+ sigmas=sigmas,
519
+ **scheduler_kwargs,
520
+ )
521
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
522
+ self._num_timesteps = len(timesteps)
523
+
524
+ # 6. Denoising loop
525
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
526
+ for i, t in enumerate(timesteps):
527
+ if self.interrupt:
528
+ continue
529
+
530
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
531
+ timestep = t.expand(latents.shape[0])
532
+ timestep = (1000 - timestep) / 1000
533
+ # Normalized time for time-aware config (0 at start, 1 at end)
534
+ t_norm = timestep[0].item()
535
+
536
+ # Handle cfg truncation
537
+ current_guidance_scale = self.guidance_scale
538
+ if (
539
+ self.do_classifier_free_guidance
540
+ and self._cfg_truncation is not None
541
+ and float(self._cfg_truncation) <= 1
542
+ ):
543
+ if t_norm > self._cfg_truncation:
544
+ current_guidance_scale = 0.0
545
+
546
+ # Run CFG only if configured AND scale is non-zero
547
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
548
+
549
+ if apply_cfg:
550
+ latents_typed = latents.to(self.transformer.dtype)
551
+ latent_model_input = latents_typed.repeat(2, 1, 1, 1)
552
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
553
+ timestep_model_input = timestep.repeat(2)
554
+ else:
555
+ latent_model_input = latents.to(self.transformer.dtype)
556
+ prompt_embeds_model_input = prompt_embeds
557
+ timestep_model_input = timestep
558
+
559
+ latent_model_input = latent_model_input.unsqueeze(2)
560
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
561
+
562
+ model_out_list = self.transformer(
563
+ latent_model_input_list,
564
+ timestep_model_input,
565
+ prompt_embeds_model_input,
566
+ control_context=control_context,
567
+ control_context_scale=control_context_scale,
568
+ )[0]
569
+
570
+ if apply_cfg:
571
+ # Perform CFG
572
+ pos_out = model_out_list[:actual_batch_size]
573
+ neg_out = model_out_list[actual_batch_size:]
574
+
575
+ noise_pred = []
576
+ for j in range(actual_batch_size):
577
+ pos = pos_out[j].float()
578
+ neg = neg_out[j].float()
579
+
580
+ pred = pos + current_guidance_scale * (pos - neg)
581
+
582
+ # Renormalization
583
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
584
+ ori_pos_norm = torch.linalg.vector_norm(pos)
585
+ new_pos_norm = torch.linalg.vector_norm(pred)
586
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
587
+ if new_pos_norm > max_new_norm:
588
+ pred = pred * (max_new_norm / new_pos_norm)
589
+
590
+ noise_pred.append(pred)
591
+
592
+ noise_pred = torch.stack(noise_pred, dim=0)
593
+ else:
594
+ noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)
595
+
596
+ noise_pred = noise_pred.squeeze(2)
597
+ noise_pred = -noise_pred
598
+
599
+ # compute the previous noisy sample x_t -> x_t-1
600
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
601
+ assert latents.dtype == torch.float32
602
+
603
+ if callback_on_step_end is not None:
604
+ callback_kwargs = {}
605
+ for k in callback_on_step_end_tensor_inputs:
606
+ callback_kwargs[k] = locals()[k]
607
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
608
+
609
+ latents = callback_outputs.pop("latents", latents)
610
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
611
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
612
+
613
+ # call the callback, if provided
614
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
615
+ progress_bar.update()
616
+
617
+ if output_type == "latent":
618
+ image = latents
619
+
620
+ else:
621
+ latents = latents.to(self.vae.dtype)
622
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
623
+
624
+ image = self.vae.decode(latents, return_dict=False)[0]
625
+ image = self.image_processor.postprocess(image, output_type=output_type)
626
+
627
+ # Offload all models
628
+ self.maybe_free_model_hooks()
629
+
630
+ if not return_dict:
631
+ return (image,)
632
+
633
+ return ZImagePipelineOutput(images=image)
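This diff adds the control pipeline but no example of driving it, so here is a hedged end-to-end sketch. The checkpoint path, subfolder layout, text-encoder class, export path, and file names below are assumptions for illustration, not part of this commit; the component classes are assumed to follow the standard diffusers/transformers `from_pretrained` API.

```py
import torch
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModelForCausalLM, AutoTokenizer

from videox_fun.models import AutoencoderKL, ZImageTransformer2DModel
from videox_fun.pipeline import ZImageControlPipeline  # assumed export path

model_path = "models/Z-Image-Turbo"  # hypothetical local checkpoint with diffusers-style subfolders

pipe = ZImageControlPipeline(
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
    vae=AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.bfloat16),
    text_encoder=AutoModelForCausalLM.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16),
    tokenizer=AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer"),
    transformer=ZImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=torch.bfloat16),
).to("cuda")

control = Image.open("examples/pose.jpg")  # a pre-computed control map (pose, depth, edges, ...)
image = pipe(
    prompt="a dancer on a rooftop at sunset",
    control_image=control,
    control_context_scale=1.0,
    height=1024,
    width=1024,
    num_inference_steps=9,
    guidance_scale=0.0,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("zimage_control.png")
```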
videox_fun/reward/MPS/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository.